diff --git a/docs/api/functions.md b/docs/api/functions.md index 82c8f48..fbbfe04 100644 --- a/docs/api/functions.md +++ b/docs/api/functions.md @@ -56,4 +56,8 @@ Note that these do *not* inspect the values of the operator -- instead, they use --- +::: lineax.is_hermitian + +--- + ::: lineax.max_rank diff --git a/docs/api/operators.md b/docs/api/operators.md index 3180df1..a924521 100644 --- a/docs/api/operators.md +++ b/docs/api/operators.md @@ -16,6 +16,7 @@ Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such tha - mv - as_matrix - transpose + - H - in_structure - out_structure - in_size diff --git a/docs/api/solvers.md b/docs/api/solvers.md index b3159d9..d57ec6a 100644 --- a/docs/api/solvers.md +++ b/docs/api/solvers.md @@ -44,6 +44,13 @@ These are capable of solving ill-posed linear problems. --- +::: lineax.HEVD + options: + members: + - __init__ + +--- + ::: lineax.Normal options: members: diff --git a/docs/api/tags.md b/docs/api/tags.md index 18f582e..c26a8d4 100644 --- a/docs/api/tags.md +++ b/docs/api/tags.md @@ -68,7 +68,13 @@ lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True ::: lineax.symmetric_tag -Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) +Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) For real operators this coincides with `lineax.hermitian_tag`. + +--- + +::: lineax.hermitian_tag + +Marks that an operator is Hermitian (self-adjoint). (As a matrix, $A = A^*$, the conjugate transpose.) For real operators this coincides with `lineax.symmetric_tag`. --- diff --git a/lineax/__init__.py b/lineax/__init__.py index 4f370de..1083400 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -27,6 +27,7 @@ has_unit_diagonal as has_unit_diagonal, IdentityLinearOperator as IdentityLinearOperator, is_diagonal as is_diagonal, + is_hermitian as is_hermitian, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, is_positive_semidefinite as is_positive_semidefinite, @@ -49,16 +50,17 @@ from ._solution import RESULTS as RESULTS, Solution as Solution from ._solve import ( AbstractLinearSolver as AbstractLinearSolver, - AutoLinearSolver as AutoLinearSolver, invert as invert, linear_solve as linear_solve, ) from ._solver import ( + AutoLinearSolver as AutoLinearSolver, BiCGStab as BiCGStab, CG as CG, Cholesky as Cholesky, Diagonal as Diagonal, GMRES as GMRES, + HEVD as HEVD, LSMR as LSMR, LU as LU, Normal as Normal, @@ -70,6 +72,7 @@ ) from ._tags import ( diagonal_tag as diagonal_tag, + hermitian_tag as hermitian_tag, invert_tags as invert_tags, invert_tags_rules as invert_tags_rules, lower_triangular_tag as lower_triangular_tag, diff --git a/lineax/_operator/__init__.py b/lineax/_operator/__init__.py index 5a61380..b4493e6 100644 --- a/lineax/_operator/__init__.py +++ b/lineax/_operator/__init__.py @@ -18,6 +18,7 @@ diagonal as diagonal, has_unit_diagonal as has_unit_diagonal, is_diagonal as is_diagonal, + is_hermitian as is_hermitian, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, is_positive_semidefinite as is_positive_semidefinite, diff --git a/lineax/_operator/base.py b/lineax/_operator/base.py index f2d8593..470aebf 100644 --- a/lineax/_operator/base.py +++ b/lineax/_operator/base.py @@ -87,6 +87,7 @@ class AbstractLinearOperator(eqx.Module): def __check_init__(self): if ( is_symmetric(self) + or is_hermitian(self) or is_positive_semidefinite(self) or is_negative_semidefinite(self) ): @@ -197,6 +198,19 @@ def T(self) -> "AbstractLinearOperator": """Equivalent to [`lineax.AbstractLinearOperator.transpose`][]""" return self.transpose() + @property + def H(self) -> "AbstractLinearOperator": + """The conjugate transpose (Hermitian adjoint) `Aᴴ` of this operator. + + Equivalent to `lineax.conj(operator).transpose()`, except that for a Hermitian + operator -- for which `Aᴴ = A` -- this is a no-op and returns `self` unchanged. + (As with [`lineax.is_hermitian`][], only the tag is checked, not the actual + values of the operator.) + """ + if is_hermitian(self): + return self + return conj(self).transpose() + def __add__(self, other) -> "AbstractLinearOperator": # Local imports to avoid a circular dependency: `binary`/`wrapper` import # `base`, so the operators built here are imported lazily at call time. @@ -399,6 +413,20 @@ def tridiagonal( _default_not_implemented("tridiagonal", operator) +def has_real_dtype(operator) -> bool: + """Check if all dtypes in an operator's structure are real (not complex).""" + leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure())) + dtype = jnp.result_type(*leaves) + if jnp.issubdtype(dtype, jnp.complexfloating): + return False + elif jnp.issubdtype(dtype, jnp.floating): + return True + else: + assert False, ( + "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood." + ) + + @ft.singledispatch def is_symmetric(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as symmetric. @@ -417,6 +445,32 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool: _default_not_implemented("is_symmetric", operator) +@ft.singledispatch +def is_hermitian(operator: AbstractLinearOperator) -> bool: + """Returns whether an operator is marked as Hermitian (self-adjoint). + + See [the documentation on linear operator tags](../api/tags.md) for more + information. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + Either `True` or `False.` + """ + # Default for operators that don't register `is_hermitian` explicitly (e.g. custom + # `AbstractLinearOperator`s written before it existed): derive it from the other + # property checks, the same way the built-in operators do minus the + # `hermitian_tag` check, which needs operator-specific `.tags`. + if is_positive_semidefinite(operator) or is_negative_semidefinite(operator): + return True + if has_real_dtype(operator) and is_symmetric(operator): + return True + return False + + @ft.singledispatch def is_diagonal(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as diagonal. diff --git a/lineax/_operator/binary.py b/lineax/_operator/binary.py index 52ad764..898c074 100644 --- a/lineax/_operator/binary.py +++ b/lineax/_operator/binary.py @@ -23,8 +23,10 @@ AbstractLinearOperator, conj, diagonal, + has_real_dtype, has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -203,6 +205,7 @@ def _(operator): for check in ( is_symmetric, + is_hermitian, is_diagonal, is_lower_triangular, is_upper_triangular, @@ -247,6 +250,17 @@ def _(operator): return is_diagonal(operator.operator1) and is_diagonal(operator.operator2) +# is_hermitian: as above, diagonal matrices commute. A product of diagonals is itself +# diagonal, which is Hermitian only when its (complex) entries are real-valued. +@is_hermitian.register(ComposedLinearOperator) +def _(operator): + return ( + is_diagonal(operator.operator1) + and is_diagonal(operator.operator2) + and has_real_dtype(operator) + ) + + # is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but # tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal @is_tridiagonal.register(ComposedLinearOperator) diff --git a/lineax/_operator/core.py b/lineax/_operator/core.py index 20673a2..df9e931 100644 --- a/lineax/_operator/core.py +++ b/lineax/_operator/core.py @@ -38,6 +38,7 @@ ) from .._tags import ( diagonal_tag, + hermitian_tag, lower_triangular_tag, negative_semidefinite_tag, positive_semidefinite_tag, @@ -53,9 +54,11 @@ conj, diagonal, FlatPyTree, + has_real_dtype, has_unit_diagonal, inexact_structure, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -724,20 +727,6 @@ def _(operator): # checks -def _has_real_dtype(operator) -> bool: - """Check if all dtypes in an operator's structure are real (not complex).""" - leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure())) - dtype = jnp.result_type(*leaves) - if jnp.issubdtype(dtype, jnp.complexfloating): - return False - elif jnp.issubdtype(dtype, jnp.floating): - return True - else: - assert False, ( - "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood." - ) - - @is_symmetric.register(MatrixLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @@ -746,12 +735,29 @@ def _(operator): # Symmetric (A = A^T) if explicitly tagged symmetric or diagonal if symmetric_tag in operator.tags or diagonal_tag in operator.tags: return True - # PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian + # PSD/NSD/Hermitian imply A = A^T only for real dtypes if ( positive_semidefinite_tag in operator.tags or negative_semidefinite_tag in operator.tags + or hermitian_tag in operator.tags + ): + return has_real_dtype(operator) + return False + + +@is_hermitian.register(MatrixLinearOperator) +@is_hermitian.register(PyTreeLinearOperator) +@is_hermitian.register(JacobianLinearOperator) +@is_hermitian.register(FunctionLinearOperator) +def _(operator): + if ( + hermitian_tag in operator.tags + or positive_semidefinite_tag in operator.tags + or negative_semidefinite_tag in operator.tags ): - return _has_real_dtype(operator) + return True + if symmetric_tag in operator.tags or diagonal_tag in operator.tags: + return has_real_dtype(operator) return False diff --git a/lineax/_operator/structured.py b/lineax/_operator/structured.py index f0195c6..9a0f0c2 100644 --- a/lineax/_operator/structured.py +++ b/lineax/_operator/structured.py @@ -40,9 +40,11 @@ conj, diagonal, FlatPyTree, + has_real_dtype, has_unit_diagonal, inexact_structure, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -292,6 +294,7 @@ def _(operator): @is_symmetric.register(IdentityLinearOperator) +@is_hermitian.register(IdentityLinearOperator) def _(operator): return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True @@ -301,7 +304,13 @@ def _(operator): return True +@is_hermitian.register(DiagonalLinearOperator) +def _(operator): + return has_real_dtype(operator) + + @is_symmetric.register(TridiagonalLinearOperator) +@is_hermitian.register(TridiagonalLinearOperator) def _(operator): return False diff --git a/lineax/_operator/wrapper.py b/lineax/_operator/wrapper.py index e9d7576..fd28052 100644 --- a/lineax/_operator/wrapper.py +++ b/lineax/_operator/wrapper.py @@ -28,6 +28,7 @@ from .._tags import ( diagonal_tag, + hermitian_tag, lower_triangular_tag, MaxRankTag, negative_semidefinite_tag, @@ -43,8 +44,10 @@ as_frozenset, conj, diagonal, + has_real_dtype, has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -340,6 +343,7 @@ def _(operator): for check in ( is_symmetric, + is_hermitian, is_diagonal, has_unit_diagonal, is_lower_triangular, @@ -371,6 +375,28 @@ def _(operator, check=check): return check(operator.operator) +def _scalar_is_real(scalar) -> bool: + """Whether a scalar is statically known to be real-valued. + + A real dtype guarantees a real value, so this is known even for JAX tracers (whose + runtime value is unknown at trace time): only the dtype matters, not the value. + Returns `False` only for genuinely complex-typed scalars. + """ + return not jnp.issubdtype(jnp.result_type(scalar), jnp.complexfloating) + + +# Hermitian-ness preserved by negation and scaling by any real scalar +@is_hermitian.register(NegLinearOperator) +def _(operator): + return is_hermitian(operator.operator) + + +@is_hermitian.register(MulLinearOperator) +@is_hermitian.register(DivLinearOperator) +def _(operator): + return _scalar_is_real(operator.scalar) and is_hermitian(operator.operator) + + # has_unit_diagonal is NOT preserved by negation @has_unit_diagonal.register(NegLinearOperator) def _(operator): @@ -498,6 +524,26 @@ def _(operator, check=check, tag=tag): return (tag in operator.tags) or check(operator.operator) +# `is_hermitian` is special-cased rather than handled by the loop above: a tag other +# than `hermitian_tag` can still imply Hermitian-ness. PSD/NSD operators are Hermitian +# (real or complex), and real symmetric/diagonal operators are Hermitian too. This +# mirrors the cross-implications encoded for the core operators. +@is_hermitian.register(TaggedLinearOperator) +def _(operator): + tags = operator.tags + if is_hermitian(operator.operator): + return True + if ( + hermitian_tag in tags + or positive_semidefinite_tag in tags + or negative_semidefinite_tag in tags + ): + return True + if symmetric_tag in tags or diagonal_tag in tags: + return has_real_dtype(operator) + return False + + @max_rank.register(TaggedLinearOperator) def _(operator): inner = max_rank(operator.operator) diff --git a/lineax/_solve.py b/lineax/_solve.py index 817d710..a33e18e 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import functools as ft -from typing import Any, Generic, TypeAlias, TypeVar +from typing import Any, TypeAlias import equinox as eqx import equinox.internal as eqxi @@ -26,28 +25,34 @@ import jax.tree_util as jtu from equinox.internal import ω from jax._src.ad_util import stop_gradient_p -from jaxtyping import Array, ArrayLike, PyTree +from jaxtyping import ArrayLike, PyTree from ._custom_types import sentinel from ._misc import inexact_asarray, strip_weak_dtype from ._operator import ( AbstractLinearOperator, - conj, FunctionLinearOperator, IdentityLinearOperator, - is_diagonal, - is_lower_triangular, - is_negative_semidefinite, - is_positive_semidefinite, - is_tridiagonal, - is_upper_triangular, + is_hermitian, linearise, max_rank, + TaggedLinearOperator, TangentLinearOperator, ) from ._solution import RESULTS, Solution +from ._solver import ( + AutoLinearSolver as AutoLinearSolver, + Cholesky, + HEVD, + Normal, + QR, + SVD, +) +from ._solver.base import AbstractLinearSolver as AbstractLinearSolver +from ._solver.misc import pack_structures from ._tags import ( invert_tags, + positive_semidefinite_tag, tags_from_checks, ) @@ -142,6 +147,95 @@ def _linear_solve_abstract_eval(operator, state, vector, options, solver, throw) return out +def _squared_rcond(rcond: float | None, n: int, m: int, dtype) -> float: + """`resolve_rcond(rcond, n, m, dtype) ** 2`, as a Python `float`. + + A pure-Python mirror of [`resolve_rcond`][] (no `jnp.where`), so the result stays a + Python scalar -- HEVD's `rcond` is `float | None`, and under the recursive gram + solve `resolve_rcond` of a non-`None` rcond would otherwise produce a traced array. + + Squaring reproduces the original solver's rank cutoff on the gram's eigenvalues + `σ²`, since `σ² > rcond²·σ²ₘₐₓ <=> σ > rcond·σₘₐₓ`. + """ + eps = float(jnp.finfo(dtype).eps) + if rcond is None: + rcond = 2 * eps * max(n, m) + elif rcond < 0: + rcond = eps + return float(rcond) ** 2 + + +# Solver types whose factorisation *may* cheaply yield the gram (pseudo)inverse +# `(AᴴA)⁺`. Whether one actually does can depend on the state (e.g. `Normal` only when +# tall), so `_has_gram_partner` is the definitive runtime check; see `_gram_partner`. +_MaybeHasGramPartner: TypeAlias = QR | SVD | HEVD | Normal + + +def _has_gram_partner(solver: AbstractLinearSolver, state: Any) -> bool: + """Can the (pseudo)inverse for A^H A actually be inferred from solver's state?""" + if isinstance(solver, Normal): + _, tall, _, _ = state + return tall.value # inner operator is `AᴴA`; when wide it is `AAᴴ` + return isinstance(solver, _MaybeHasGramPartner) + + +def _gram_partner( + solver: _MaybeHasGramPartner, + gram_operator: AbstractLinearOperator, + state: Any, +) -> tuple[AbstractLinearSolver, Any]: + """Return a `(gram_solver, gram_state)` pair such that + `linear_solve_p(gram_operator, gram_state, v, gram_solver)` computes `(AᴴA)⁺ v` + (where `gram_operator` is `AᴴA`). Requires `_has_gram_partner(solver, state)`. + + Each candidate solver's gram partner -- a solver representing the (pseudo)inverse of + the gram matrix `AᴴA` -- is obtained from the existing factorisation with no further + decomposition: + + QR `A = QR` -> `Cholesky`, since `AᴴA = RᴴR` (`R` is the factor) + SVD `A = UΣVᴴ` -> `HEVD` with eigenvectors `V`, eigenvalues `σ²` + HEVD `A = VWVᴴ` -> `HEVD` with eigenvectors `V`, eigenvalues `w²` + Normal (tall) -> its inner solver, which already factorises `AᴴA` + + The JVP uses this to collapse the two nested solves against `Aᴴ` (the inner adjoint + solve and the outer `A⁺`) into one gram solve. Routing it back through + `linear_solve_p` -- rather than applying the factors directly -- keeps it correct + under higher-order autodiff, since the gram solve then uses lineax's + pseudoinverse-aware adjoint rather than differentiating through the factorisation. + """ + if isinstance(solver, Normal): + inner_state, tall, _, _ = state + if not tall.value: + # Wide: the inner solver factorises `AAᴴ`, not `AᴴA`. `_has_gram_partner` + # excludes this, so reaching here is a caller bug. + raise ValueError("`Normal` has a gram partner only for tall operators") + # Tall: the inner solver already factorises `AᴴA`, so its state *is* the gram + # state. This holds for any inner solver (Cholesky, CG, HEVD, ...). + return solver.inner_solver, inner_state + if isinstance(solver, QR): + (a, _), transpose, _ = state + if transpose.value: + # QR is full rank, so the JVP reaches the gram path only when + # `rows > columns` (tall), where the stored factorisation is of `A`. + raise ValueError("`QR` has a gram partner only for tall operators") + # Tall `A = QR` => `AᴴA = RᴴR`: the QR factor `R` is the upper Cholesky factor. + r = a[: a.shape[1]] + return Cholesky(), (r, eqxi.Static(False)) + packed = pack_structures(gram_operator) + if isinstance(solver, SVD): + (u, s, vt), _ = state + # `(AᴴA)⁺ = V Σ⁻² Vᴴ`. + eigenvalues, eigenvectors = s**2, vt.conj().T + rcond = _squared_rcond(solver.rcond, vt.shape[1], u.shape[0], s.dtype) + else: + (w, eigenvectors), _ = state + # `(AᴴA)⁺ = (A²)⁺ = V W⁻² Vᴴ`. + eigenvalues = w**2 + m = eigenvectors.shape[0] + rcond = _squared_rcond(solver.rcond, m, m, w.dtype) + return HEVD(rcond=rcond), ((eigenvalues, eigenvectors), packed) + + @eqxi.filter_primitive_jvp def _linear_solve_jvp(primals, tangents): operator, state, vector, options, solver, throw = primals @@ -208,25 +302,55 @@ def _linear_solve_jvp(primals, tangents): assume_independent_rows = solver.assume_full_rank() and rows <= columns assume_independent_columns = solver.assume_full_rank() and columns <= rows if not assume_independent_rows or not assume_independent_columns: - operator_conj_transpose = conj(operator).transpose() - t_operator_conj_transpose = conj(t_operator).transpose() - state_conj, options_conj = solver.conj(state, options) - state_conj_transpose, options_conj_transpose = solver.transpose( - state_conj, options_conj - ) + operator_conj_transpose = operator.H + t_operator_conj_transpose = t_operator.H + if is_hermitian(operator): + # `Aᴴ = A`, so `init(Aᴴ) == init(A)`: the existing state already serves + # as the adjoint state. This holds for any solver, so the fast path is + # keyed on the operator rather than on the solver. + state_conj_transpose, options_conj_transpose = state, options + else: + state_conj, options_conj = solver.conj(state, options) + state_conj_transpose, options_conj_transpose = solver.transpose( + state_conj, options_conj + ) if not assume_independent_rows: lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω tmp = t_operator_conj_transpose.mv(lst_sqr_diff) # pyright: ignore - tmp, _, _ = eqxi.filter_primitive_bind( - linear_solve_p, - operator_conj_transpose, # pyright: ignore - state_conj_transpose, # pyright: ignore - tmp, - options_conj_transpose, # pyright: ignore - solver, - True, - ) - vecs.append(tmp) + # This term is `A⁺ (Aᴴ)⁺ w = (AᴴA)⁺ w`. If the solver has a gram partner, + # compute `(AᴴA)⁺ w` in a single gram solve against `AᴴA`; otherwise fall + # back to the generic nested adjoint solve (whose result is later + # left-multiplied by `A⁺` along with the other `vecs`). The gram operator + # is never materialised -- the gram solve reads only `gram_state` -- but it + # carries the right structure and (for higher-order autodiff) tangent. + if _has_gram_partner(solver, state): + gram_operator = TaggedLinearOperator( + operator.H @ operator, positive_semidefinite_tag + ) + gram_solver, gram_state = _gram_partner(solver, gram_operator, state) + gram_inv, _, _ = eqxi.filter_primitive_bind( + linear_solve_p, + gram_operator, + gram_state, + tmp, + {}, + gram_solver, + True, + ) + # `(AᴴA)⁺ w` already lives in the input space, so it bypasses the + # outer `A⁺`: append directly to the already-solved `sols`. + sols.append(gram_inv) + else: + tmp, _, _ = eqxi.filter_primitive_bind( + linear_solve_p, + operator_conj_transpose, # pyright: ignore + state_conj_transpose, # pyright: ignore + tmp, + options_conj_transpose, # pyright: ignore + solver, + True, + ) + vecs.append(tmp) if not assume_independent_columns: tmp1, _, _ = eqxi.filter_primitive_bind( @@ -331,149 +455,6 @@ def _linear_solve_transpose(inputs, cts_out): # -_SolverState = TypeVar("_SolverState") - - -class AbstractLinearSolver(eqx.Module, Generic[_SolverState]): - """Abstract base class for all linear solvers.""" - - @abc.abstractmethod - def init( - self, operator: AbstractLinearOperator, options: dict[str, Any] - ) -> _SolverState: - """Do any initial computation on just the `operator`. - - For example, an LU solver would compute the LU decomposition of the operator - (and this does not require knowing the vector yet). - - It is common to need to solve the linear system `Ax=b` multiple times in - succession, with the same operator `A` and multiple vectors `b`. This method - improves efficiency by making it possible to re-use the computation performed - on just the operator. - - !!! Example - - ```python - operator = lx.MatrixLinearOperator(...) - vector1 = ... - vector2 = ... - solver = lx.LU() - state = solver.init(operator, options={}) - solution1 = lx.linear_solve(operator, vector1, solver, state=state) - solution2 = lx.linear_solve(operator, vector2, solver, state=state) - ``` - - **Arguments:** - - - `operator`: a linear operator. - - `options`: a dictionary of any extra options that the solver may wish to - accept. - - **Returns:** - - A PyTree of arbitrary Python objects. - """ - - @abc.abstractmethod - def compute( - self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any] - ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - """Solves a linear system. - - **Arguments:** - - - `state`: as returned from [`lineax.AbstractLinearSolver.init`][]. - - `vector`: the vector to solve against. - - `options`: a dictionary of any extra options that the solver may wish to - accept. For example, [`lineax.CG`][] accepts a `preconditioner` option. - - **Returns:** - - A 3-tuple of: - - - The solution to the linear system. - - An integer indicating the success or failure of the solve. This is an integer - which may be converted to a human-readable error message via - `lx.RESULTS[...]`. - - A dictionary of an extra statistics about the solve, e.g. the number of steps - taken. - """ - - @abc.abstractmethod - def transpose( - self, state: _SolverState, options: dict[str, Any] - ) -> tuple[_SolverState, dict[str, Any]]: - """Transposes the result of [`lineax.AbstractLinearSolver.init`][]. - - That is, it should be the case that - ```python - state_transpose, _ = solver.transpose(solver.init(operator, options), options) - state_transpose2 = solver.init(operator.T, options) - ``` - must be identical to each other. - - It is relatively common (in particular when differentiating through a linear - solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it - possible to avoid computing both `solver.init(operator)` and - `solver.init(operator.T)` if one can be cheaply computed from the other. - - **Arguments:** - - - `state`: as returned from `solver.init`. - - `options`: any extra options that were passed to `solve.init`. - - **Returns:** - - A 2-tuple of: - - - The state of the transposed operator. - - The options for the transposed operator. - """ - - @abc.abstractmethod - def conj( - self, state: _SolverState, options: dict[str, Any] - ) -> tuple[_SolverState, dict[str, Any]]: - """Conjugate the result of [`lineax.AbstractLinearSolver.init`][]. - - That is, it should be the case that - ```python - state_conj, _ = solver.conj(solver.init(operator, options), options) - state_conj2 = solver.init(conj(operator), options) - ``` - must be identical to each other. - - **Arguments:** - - - `state`: as returned from `solver.init`. - - `options`: any extra options that were passed to `solve.init`. - - **Returns:** - - A 2-tuple of: - - - The state of the conjugated operator. - - The options for the conjugated operator. - """ - - @abc.abstractmethod - def assume_full_rank(self) -> bool: - """Does this solver assume that all operators are full rank? - - When `False`, a more expensive backward pass is needed to account for - the extra generality. In a custom linear solver, it is always safe to - return False. - - **Arguments:** - - Nothing. - - **Returns:** - - Either `True` or `False`. - """ - - def _check_rank_compat( solver: "AbstractLinearSolver", operator: AbstractLinearOperator ): @@ -490,177 +471,6 @@ def _check_rank_compat( ) -_qr_token = eqxi.str2jax("qr_token") -_diagonal_token = eqxi.str2jax("diagonal_token") -_well_posed_diagonal_token = eqxi.str2jax("well_posed_diagonal_token") -_tridiagonal_token = eqxi.str2jax("tridiagonal_token") -_triangular_token = eqxi.str2jax("triangular_token") -_cholesky_token = eqxi.str2jax("cholesky_token") -_lu_token = eqxi.str2jax("lu_token") -_svd_token = eqxi.str2jax("svd_token") - - -# Ugly delayed import because we have the dependency chain -# linear_solve -> AutoLinearSolver -> {Cholesky,...} -> AbstractLinearSolver -# but we want linear_solver and AbstractLinearSolver in the same file. -def _lookup(token) -> AbstractLinearSolver: - from . import _solver - - # pyright doesn't know that these keys are hashable - _lookup_dict = { - _qr_token: _solver.QR(), # pyright: ignore - _diagonal_token: _solver.Diagonal(), # pyright: ignore - _well_posed_diagonal_token: _solver.Diagonal( # pyright: ignore - well_posed=True - ), - _tridiagonal_token: _solver.Tridiagonal(), # pyright: ignore - _triangular_token: _solver.Triangular(), # pyright: ignore - _cholesky_token: _solver.Cholesky(), # pyright: ignore - _lu_token: _solver.LU(), # pyright: ignore - _svd_token: _solver.SVD(), # pyright: ignore - } - return _lookup_dict[token] - - -_AutoLinearSolverState: TypeAlias = tuple[Any, Any] - - -class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): - """Automatically determines a good linear solver based on the structure of the - operator. - - - If `well_posed=True`: - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - - If the operator is triangular, then use [`lineax.Triangular`][]. - - If the matrix is positive or negative (semi-)definite, then use - [`lineax.Cholesky`][]. - - Else use [`lineax.LU`][]. - - This is a good choice if you want to be certain that an error is raised for - ill-posed systems. - - - If `well_posed=False`: - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - Else use [`lineax.SVD`][]. - - This is a good choice if you want to be certain that you can handle ill-posed - systems. - - - If `well_posed=None`: - - If the operator is non-square, then use [`lineax.QR`][]. - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - - If the operator is triangular, then use [`lineax.Triangular`][]. - - If the matrix is positive or negative (semi-)definite, then use - [`lineax.Cholesky`][]. - - Else, use [`lineax.LU`][]. - - This is a good choice if your primary concern is computational efficiency. It will - handle ill-posed systems as long as it is not computationally expensive to do so. - """ - - well_posed: bool | None - - def _select_solver(self, operator: AbstractLinearOperator): - if self.well_posed is True: - if operator.in_size() != operator.out_size(): - raise ValueError( - "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square " - "operator. If you are trying solve a least-squares problem then " - "you should pass `solver=AutoLinearSolver(well_posed=False)`. By " - "default `lineax.linear_solve` assumes that the operator is " - "square and nonsingular." - ) - if is_diagonal(operator): - token = _well_posed_diagonal_token - elif is_tridiagonal(operator): - token = _tridiagonal_token - elif is_lower_triangular(operator) or is_upper_triangular(operator): - token = _triangular_token - elif is_positive_semidefinite(operator) or is_negative_semidefinite( - operator - ): - token = _cholesky_token - else: - token = _lu_token - elif self.well_posed is False: - if is_diagonal(operator): - token = _diagonal_token - else: - # TODO: use rank-revealing QR instead. - token = _svd_token - elif self.well_posed is None: - if operator.in_size() != operator.out_size(): - token = _qr_token - elif is_diagonal(operator): - token = _diagonal_token - elif is_tridiagonal(operator): - token = _tridiagonal_token - elif is_lower_triangular(operator) or is_upper_triangular(operator): - token = _triangular_token - elif is_positive_semidefinite(operator) or is_negative_semidefinite( - operator - ): - token = _cholesky_token - else: - token = _lu_token - else: - raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") - return token - - def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: - """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. - - **Arguments:** - - - `operator`: a linear operator. - - **Returns:** - - The linear solver that will be used. - """ - return _lookup(self._select_solver(operator)) - - def init(self, operator, options) -> _AutoLinearSolverState: - token = self._select_solver(operator) - return token, _lookup(token).init(operator, options) - - def compute( - self, - state: _AutoLinearSolverState, - vector: PyTree[Array], - options: dict[str, Any], - ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - token, state = state - solver = _lookup(token) - solution, result, _ = solver.compute(state, vector, options) - return solution, result, {} - - def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): - token, state = state - solver = _lookup(token) - transpose_state, transpose_options = solver.transpose(state, options) - transpose_state = (token, transpose_state) - return transpose_state, transpose_options - - def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): - token, state = state - solver = _lookup(token) - conj_state, conj_options = solver.conj(state, options) - conj_state = (token, conj_state) - return conj_state, conj_options - - def assume_full_rank(self): - return self.well_posed is not False - - -AutoLinearSolver.__init__.__doc__ = """**Arguments:** - -- `well_posed`: whether to only handle well-posed systems or not, as discussed above. -""" - - # TODO(kidger): gmres, bicgstab # TODO(kidger): support auxiliary outputs @eqx.filter_jit diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index 2cee02c..0468326 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .auto import AutoLinearSolver as AutoLinearSolver from .bicgstab import BiCGStab as BiCGStab from .cg import CG as CG, NormalCG as NormalCG from .cholesky import Cholesky as Cholesky from .diagonal import Diagonal as Diagonal from .gmres import GMRES as GMRES +from .hevd import HEVD as HEVD from .lsmr import LSMR as LSMR from .lu import LU as LU from .normal import Normal as Normal diff --git a/lineax/_solver/auto.py b/lineax/_solver/auto.py new file mode 100644 index 0000000..4e6f4fc --- /dev/null +++ b/lineax/_solver/auto.py @@ -0,0 +1,179 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeAlias + +from jaxtyping import Array, PyTree + +from .._operator import ( + AbstractLinearOperator, + is_diagonal, + is_hermitian, + is_lower_triangular, + is_negative_semidefinite, + is_positive_semidefinite, + is_tridiagonal, + is_upper_triangular, +) +from .._solution import RESULTS +from .base import AbstractLinearSolver +from .cholesky import Cholesky +from .diagonal import Diagonal +from .hevd import HEVD +from .lu import LU +from .qr import QR +from .svd import SVD +from .triangular import Triangular +from .tridiagonal import Tridiagonal + + +_AutoLinearSolverState: TypeAlias = tuple[AbstractLinearSolver, Any] + + +class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): + """Automatically determines a good linear solver based on the structure of the + operator. + + - If `well_posed=True`: + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is triangular, then use [`lineax.Triangular`][]. + - If the matrix is positive or negative (semi-)definite, then use + [`lineax.Cholesky`][]. + - Else use [`lineax.LU`][]. + + This is a good choice if you want to be certain that an error is raised for + ill-posed systems. + + - If `well_posed=False`: + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is Hermitian, then use [`lineax.HEVD`][]. + - Else use [`lineax.SVD`][]. + + This is a good choice if you want to be certain that you can handle ill-posed + systems. + + - If `well_posed=None`: + - If the operator is non-square, then use [`lineax.QR`][]. + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is triangular, then use [`lineax.Triangular`][]. + - If the matrix is positive or negative (semi-)definite, then use + [`lineax.Cholesky`][]. + - Else, use [`lineax.LU`][]. + + This is a good choice if your primary concern is computational efficiency. It will + handle ill-posed systems as long as it is not computationally expensive to do so. + """ + + well_posed: bool | None + + def _select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: + if self.well_posed is True: + if operator.in_size() != operator.out_size(): + raise ValueError( + "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square " + "operator. If you are trying solve a least-squares problem then " + "you should pass `solver=AutoLinearSolver(well_posed=False)`. By " + "default `lineax.linear_solve` assumes that the operator is " + "square and nonsingular." + ) + if is_diagonal(operator): + solver = Diagonal(well_posed=True) + elif is_tridiagonal(operator): + solver = Tridiagonal() + elif is_lower_triangular(operator) or is_upper_triangular(operator): + solver = Triangular() + elif is_positive_semidefinite(operator) or is_negative_semidefinite( + operator + ): + solver = Cholesky() + else: + solver = LU() + elif self.well_posed is False: + if is_diagonal(operator): + solver = Diagonal() + elif is_hermitian(operator): + # A Hermitian eigendecomposition is cheaper than a general SVD, and + # handles ill-posed Hermitian systems via the same pseudoinverse. + solver = HEVD() + else: + # TODO: use rank-revealing QR instead. + solver = SVD() + elif self.well_posed is None: + if operator.in_size() != operator.out_size(): + solver = QR() + elif is_diagonal(operator): + solver = Diagonal() + elif is_tridiagonal(operator): + solver = Tridiagonal() + elif is_lower_triangular(operator) or is_upper_triangular(operator): + solver = Triangular() + elif is_positive_semidefinite(operator) or is_negative_semidefinite( + operator + ): + solver = Cholesky() + else: + solver = LU() + else: + raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") + return solver + + def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: + """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + The linear solver that will be used. + """ + return self._select_solver(operator) + + def init(self, operator, options) -> _AutoLinearSolverState: + solver = self._select_solver(operator) + return solver, solver.init(operator, options) + + def compute( + self, + state: _AutoLinearSolverState, + vector: PyTree[Array], + options: dict[str, Any], + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + solver, state = state + solution, result, _ = solver.compute(state, vector, options) + return solution, result, {} + + def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): + solver, state = state + transpose_state, transpose_options = solver.transpose(state, options) + transpose_state = (solver, transpose_state) + return transpose_state, transpose_options + + def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): + solver, state = state + conj_state, conj_options = solver.conj(state, options) + conj_state = (solver, conj_state) + return conj_state, conj_options + + def assume_full_rank(self): + return self.well_posed is not False + + +AutoLinearSolver.__init__.__doc__ = """**Arguments:** + +- `well_posed`: whether to only handle well-posed systems or not, as discussed above. +""" diff --git a/lineax/_solver/base.py b/lineax/_solver/base.py new file mode 100644 index 0000000..9984221 --- /dev/null +++ b/lineax/_solver/base.py @@ -0,0 +1,165 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Generic, TypeVar + +import equinox as eqx +from jaxtyping import Array, PyTree + +from .._operator import AbstractLinearOperator +from .._solution import RESULTS + + +_SolverState = TypeVar("_SolverState") + + +class AbstractLinearSolver(eqx.Module, Generic[_SolverState]): + """Abstract base class for all linear solvers.""" + + @abc.abstractmethod + def init( + self, operator: AbstractLinearOperator, options: dict[str, Any] + ) -> _SolverState: + """Do any initial computation on just the `operator`. + + For example, an LU solver would compute the LU decomposition of the operator + (and this does not require knowing the vector yet). + + It is common to need to solve the linear system `Ax=b` multiple times in + succession, with the same operator `A` and multiple vectors `b`. This method + improves efficiency by making it possible to re-use the computation performed + on just the operator. + + !!! Example + + ```python + operator = lx.MatrixLinearOperator(...) + vector1 = ... + vector2 = ... + solver = lx.LU() + state = solver.init(operator, options={}) + solution1 = lx.linear_solve(operator, vector1, solver, state=state) + solution2 = lx.linear_solve(operator, vector2, solver, state=state) + ``` + + **Arguments:** + + - `operator`: a linear operator. + - `options`: a dictionary of any extra options that the solver may wish to + accept. + + **Returns:** + + A PyTree of arbitrary Python objects. + """ + + @abc.abstractmethod + def compute( + self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any] + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + """Solves a linear system. + + **Arguments:** + + - `state`: as returned from [`lineax.AbstractLinearSolver.init`][]. + - `vector`: the vector to solve against. + - `options`: a dictionary of any extra options that the solver may wish to + accept. For example, [`lineax.CG`][] accepts a `preconditioner` option. + + **Returns:** + + A 3-tuple of: + + - The solution to the linear system. + - An integer indicating the success or failure of the solve. This is an integer + which may be converted to a human-readable error message via + `lx.RESULTS[...]`. + - A dictionary of an extra statistics about the solve, e.g. the number of steps + taken. + """ + + @abc.abstractmethod + def transpose( + self, state: _SolverState, options: dict[str, Any] + ) -> tuple[_SolverState, dict[str, Any]]: + """Transposes the result of [`lineax.AbstractLinearSolver.init`][]. + + That is, it should be the case that + ```python + state_transpose, _ = solver.transpose(solver.init(operator, options), options) + state_transpose2 = solver.init(operator.T, options) + ``` + must be identical to each other. + + It is relatively common (in particular when differentiating through a linear + solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it + possible to avoid computing both `solver.init(operator)` and + `solver.init(operator.T)` if one can be cheaply computed from the other. + + **Arguments:** + + - `state`: as returned from `solver.init`. + - `options`: any extra options that were passed to `solve.init`. + + **Returns:** + + A 2-tuple of: + + - The state of the transposed operator. + - The options for the transposed operator. + """ + + @abc.abstractmethod + def conj( + self, state: _SolverState, options: dict[str, Any] + ) -> tuple[_SolverState, dict[str, Any]]: + """Conjugate the result of [`lineax.AbstractLinearSolver.init`][]. + + That is, it should be the case that + ```python + state_conj, _ = solver.conj(solver.init(operator, options), options) + state_conj2 = solver.init(conj(operator), options) + ``` + must be identical to each other. + + **Arguments:** + + - `state`: as returned from `solver.init`. + - `options`: any extra options that were passed to `solve.init`. + + **Returns:** + + A 2-tuple of: + + - The state of the conjugated operator. + - The options for the conjugated operator. + """ + + @abc.abstractmethod + def assume_full_rank(self) -> bool: + """Does this solver assume that all operators are full rank? + + When `False`, a more expensive backward pass is needed to account for + the extra generality. In a custom linear solver, it is always safe to + return False. + + **Arguments:** + + Nothing. + + **Returns:** + + Either `True` or `False`. + """ diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 31cf241..65ce784 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -25,7 +25,7 @@ from .._norm import max_norm, tree_dot from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index eee8c97..cd9655a 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -34,7 +34,7 @@ linearise, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 from .normal import Normal diff --git a/lineax/_solver/cholesky.py b/lineax/_solver/cholesky.py index 852ab70..6a352fe 100644 --- a/lineax/_solver/cholesky.py +++ b/lineax/_solver/cholesky.py @@ -25,7 +25,7 @@ is_positive_semidefinite, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver _CholeskyState: TypeAlias = tuple[Array, eqxi.Static] diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 334e96e..c271639 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -20,7 +20,7 @@ from .._misc import resolve_rcond from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index d5911a0..aabe4bf 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -28,7 +28,7 @@ from .._norm import max_norm, two_norm from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator from .._solution import RESULTS -from .._solve import AbstractLinearSolver, linear_solve +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 from .qr import QR @@ -300,6 +300,8 @@ def buffers(carry): ) coeff_op_transpose = MatrixLinearOperator(coeff_mat.T) # TODO(raderj): move to a Hessenberg-specific solver + from .._solve import linear_solve + z = linear_solve(coeff_op_transpose, beta_vec, QR(), throw=False).value diff = jtu.tree_map( lambda mat: jnp.tensordot( diff --git a/lineax/_solver/hevd.py b/lineax/_solver/hevd.py new file mode 100644 index 0000000..7a5492b --- /dev/null +++ b/lineax/_solver/hevd.py @@ -0,0 +1,155 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeAlias + +import equinox as eqx +import jax.lax as lax +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from .._misc import resolve_rcond +from .._operator import ( + AbstractLinearOperator, + is_hermitian, + is_negative_semidefinite, + is_positive_semidefinite, + max_rank, +) +from .._solution import RESULTS +from .base import AbstractLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + unravel_solution, +) + + +_HEVDState: TypeAlias = tuple[tuple[Array, Array], PackedStructures] + + +class HEVD(AbstractLinearSolver[_HEVDState]): + """Eigenvalue decomposition solver for Hermitian linear systems. + + The operator must be square and Hermitian (self-adjoint), but need not be + definite or even nonsingular: in the singular case this solver returns + the pseudoinverse solution. This is the optimised Hermitian analogue of + [`lineax.SVD`][]. + """ + + rcond: float | None = None + + def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): + del options + if not is_hermitian(operator): + raise ValueError( + "`HEVD` may only be used for linear solves with Hermitian matrices." + ) + # `jnp.linalg.eigh` returns eigenvalues in ascending (signed) order + w, v = jnp.linalg.eigh(operator.as_matrix()) + r = max_rank(operator) + if r < w.shape[0]: + # The operator is declared to have rank at most `r`, so all but the `r` + # largest-magnitude eigenvalues are mathematically zero. Statically drop + # them to shrink the matmuls (and storage) in `compute`. + m = v.shape[0] + rcond = resolve_rcond(self.rcond, m, m, w.dtype) * jnp.max(jnp.abs(w)) + if is_positive_semidefinite(operator): + # Eigenvalues are >= 0, so in ascending order the `r` largest are a + # contiguous trailing slice (cheaper than a reordering gather). + w, v, dropped = w[m - r :], v[:, m - r :], w[: m - r] + elif is_negative_semidefinite(operator): + # Eigenvalues are <= 0, so the `r` largest in magnitude are a + # contiguous leading slice. + w, v, dropped = w[:r], v[:, :r], w[r:] + else: + # Indefinite: the small-magnitude eigenvalues sit in the interior of + # the spectrum, so no contiguous slice works. Reorder by descending + # magnitude (an O(n^2) gather, dominated by the O(n^3) eigensolve) + # and take the leading `r`. + order = jnp.argsort(jnp.abs(w))[::-1] + w, v = w[order], v[:, order] + w, v, dropped = w[:r], v[:, :r], w[r:] + # `compute` masks out `|w_i| <= rcond * max|w|`, so dropping these is + # lossless iff they all sit below that floor. Otherwise the `max_rank` + # claim is false (truncation would change the solution), so error out. + # Checking the largest discarded magnitude also catches a mistagged + # PSD/NSD operator whose true large eigenvalues sit on the dropped side. + w = eqx.error_if( + w, + jnp.max(jnp.abs(dropped)) > rcond, + "lineax.HEVD: the operator was declared (via a `MaxRankTag`, or by " + f"composition rules) to have rank at most {r}, but it has an " + "eigenvalue above the rcond threshold beyond that rank. Truncating to " + "the declared rank would change the solution, so the rank claim " + "appears to be incorrect. Remove/loosen the rank tag, increase " + "`rcond` if you intend a low-rank approximation, or set " + "`EQX_ON_ERROR=off` to skip this check.", + ) + packed_structures = pack_structures(operator) + return (w, v), packed_structures + + def compute( + self, + state: _HEVDState, + vector: PyTree[Array], + options: dict[str, Any], + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + del options + (w, v), packed_structures = state + vector = ravel_vector(vector, packed_structures) + m = v.shape[0] + rcond = resolve_rcond(self.rcond, m, m, w.dtype) + rcond = jnp.array(rcond, dtype=w.dtype) + abs_w = jnp.abs(w) + if w.size > 0: + # Eigenvalues are signed and not magnitude-sorted; scale by max |w|. + rcond = rcond * jnp.max(abs_w) + # Not >=, or this fails with a matrix of all-zeros. + mask = abs_w > rcond + rank = mask.sum() + safe_w = jnp.where(mask, w, 1) + w_inv = jnp.where(mask, jnp.array(1.0) / safe_w, 0).astype(v.dtype) + vHb = jnp.matmul(v.conj().T, vector, precision=lax.Precision.HIGHEST) + solution = jnp.matmul(v, w_inv * vHb, precision=lax.Precision.HIGHEST) + solution = unravel_solution(solution, packed_structures) + return solution, RESULTS.successful, {"rank": rank} + + def transpose(self, state: _HEVDState, options: dict[str, Any]): + del options + (w, v), packed_structures = state + # `A` is Hermitian, so `A^T = conj(A)` (with `w` real). The structure is + # square symmetric, so the packed structures are unchanged. + transpose_state = (w, v.conj()), packed_structures + transpose_options = {} + return transpose_state, transpose_options + + def conj(self, state: _HEVDState, options: dict[str, Any]): + del options + (w, v), packed_structures = state + # `A` is Hermitian, so `conj(A) = conj(V) diag(w) conj(V)^H` (with `w` real). + conj_state = (w, v.conj()), packed_structures + conj_options = {} + return conj_state, conj_options + + def assume_full_rank(self): + return False + + +HEVD.__init__.__doc__ = """**Arguments**: + +- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine + precision times `N`, where `(N, N)` is the shape of the operator. +""" diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 8491da0..9cbf092 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -46,7 +46,7 @@ from .._norm import two_norm from .._operator import AbstractLinearOperator, conj, linearise, max_rank from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver _LSMRState: TypeAlias = AbstractLinearOperator diff --git a/lineax/_solver/lu.py b/lineax/_solver/lu.py index 7283600..a751186 100644 --- a/lineax/_solver/lu.py +++ b/lineax/_solver/lu.py @@ -21,7 +21,7 @@ from .._operator import AbstractLinearOperator, is_diagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py index e46aabd..a1d205e 100644 --- a/lineax/_solver/normal.py +++ b/lineax/_solver/normal.py @@ -18,11 +18,18 @@ import equinox.internal as eqxi from jaxtyping import Array, PyTree -from .._operator import conj, linearise, materialise, TaggedLinearOperator +from .._operator import ( + AbstractLinearOperator, + conj, + linearise, + materialise, + TaggedLinearOperator, +) from .._solution import RESULTS -from .._solve import AbstractLinearOperator, AbstractLinearSolver from .._tags import positive_semidefinite_tag +from .base import AbstractLinearSolver from .cholesky import Cholesky +from .hevd import HEVD _InnerSolverState = TypeVar("_InnerSolverState") @@ -107,8 +114,8 @@ def init(self, operator, options): # Cholesky materialises op twice when computing (op^H @ op).as_matrix() # Cheaper to materialise first and then conjugate-transpose. # For iterative solvers we only linearise to avoid eager materialisation. - is_cholesky = isinstance(self.inner_solver, Cholesky) - lin_op = materialise(operator) if is_cholesky else linearise(operator) + is_direct = isinstance(self.inner_solver, Cholesky | HEVD) + lin_op = materialise(operator) if is_direct else linearise(operator) if tall: inner_operator = conj(lin_op.transpose()) @ lin_op else: diff --git a/lineax/_solver/qr.py b/lineax/_solver/qr.py index 69e4b42..c7c57a0 100644 --- a/lineax/_solver/qr.py +++ b/lineax/_solver/qr.py @@ -21,7 +21,7 @@ from jaxtyping import Array, PyTree from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/svd.py b/lineax/_solver/svd.py index 5b87209..3c4bfeb 100644 --- a/lineax/_solver/svd.py +++ b/lineax/_solver/svd.py @@ -23,7 +23,7 @@ from .._misc import resolve_rcond from .._operator import AbstractLinearOperator, max_rank from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/triangular.py b/lineax/_solver/triangular.py index 304f21a..e0d1674 100644 --- a/lineax/_solver/triangular.py +++ b/lineax/_solver/triangular.py @@ -25,7 +25,7 @@ is_upper_triangular, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/tridiagonal.py b/lineax/_solver/tridiagonal.py index 7d83eb7..e1f72e7 100644 --- a/lineax/_solver/tridiagonal.py +++ b/lineax/_solver/tridiagonal.py @@ -20,7 +20,7 @@ from .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_tags.py b/lineax/_tags.py index 83051a8..b8d6aa8 100644 --- a/lineax/_tags.py +++ b/lineax/_tags.py @@ -29,6 +29,7 @@ def __repr__(self): symmetric_tag = _HasRepr("symmetric_tag") +hermitian_tag = _HasRepr("hermitian_tag") diagonal_tag = _HasRepr("diagonal_tag") tridiagonal_tag = _HasRepr("tridiagonal_tag") unit_diagonal_tag = _HasRepr("unit_diagonal_tag") @@ -127,6 +128,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: from ._operator import ( has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -140,6 +142,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: tag for check, tag in [ (is_symmetric, symmetric_tag), + (is_hermitian, hermitian_tag), (is_diagonal, diagonal_tag), (is_lower_triangular, lower_triangular_tag), (is_upper_triangular, upper_triangular_tag), @@ -163,6 +166,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: for tag in ( symmetric_tag, + hermitian_tag, unit_diagonal_tag, diagonal_tag, positive_semidefinite_tag, @@ -227,6 +231,7 @@ def transpose_tags(tags: frozenset[object]): for tag in ( symmetric_tag, + hermitian_tag, diagonal_tag, lower_triangular_tag, upper_triangular_tag, diff --git a/tests/helpers.py b/tests/helpers.py index 3194d11..8b4a8d7 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -49,6 +49,8 @@ def _construct_matrix_impl( matrix = jnp.diag(jnp.diag(matrix)) if has_tag(tags, lx.symmetric_tag): matrix = matrix + matrix.T + if has_tag(tags, lx.hermitian_tag): + matrix = matrix + matrix.conj().T if has_tag(tags, lx.lower_triangular_tag): matrix = jnp.tril(matrix) if has_tag(tags, lx.upper_triangular_tag): @@ -64,6 +66,15 @@ def _construct_matrix_impl( matrix = matrix @ matrix.T.conj() if has_tag(tags, lx.negative_semidefinite_tag): matrix = -matrix @ matrix.T.conj() + if cond_or_singular == "zero" and ( + has_tag(tags, lx.symmetric_tag) or has_tag(tags, lx.hermitian_tag) + ): + # The symmetric/Hermitian construction refills the leading row that + # `zero` cleared, so re-zero the leading row *and* column of the result. + # This makes `e_0` a null vector -- a genuinely rank-deficient, and still + # indefinite, Hermitian operator -- mirroring how `zero` yields a + # rank-deficient matrix for the PSD/NSD constructions. + matrix = matrix.at[0, :].set(0).at[:, 0].set(0) if isinstance(cond_or_singular, str): break else: @@ -84,7 +95,10 @@ def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64): - if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)): + if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES, lx.HEVD)): + # `trim_row`/`trim_col` produce non-square matrices, which are incompatible + # with the (square) structure these solvers require. Use `zero` instead, + # which keeps the matrix square and (for PSD/NSD/Hermitian tags) Hermitian. singular_method = "zero" else: # Use `getkey()` rather than the stdlib `random.choice` for reproducibility @@ -133,6 +147,10 @@ def construct_poisson_matrix(size, dtype=jnp.float64): (lx.Cholesky(), lx.positive_semidefinite_tag, False), (lx.Cholesky(), lx.negative_semidefinite_tag, False), (lx.Normal(lx.Cholesky()), (), False), + (lx.HEVD(), lx.positive_semidefinite_tag, True), + (lx.HEVD(), lx.negative_semidefinite_tag, True), + (lx.HEVD(), lx.hermitian_tag, True), + (lx.Normal(lx.HEVD()), (), True), ] solvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse] solvers = [a for a, _, _ in solvers_tags_pseudoinverse] diff --git a/tests/test_max_rank.py b/tests/test_max_rank.py index a867547..8c0eb9e 100644 --- a/tests/test_max_rank.py +++ b/tests/test_max_rank.py @@ -256,3 +256,87 @@ def test_svd_raises_when_max_rank_too_small(): vector = jnp.arange(10.0) + 0.5 with pytest.raises(eqx.EquinoxRuntimeError): lx.linear_solve(operator, vector, lx.SVD()) + + +# --------------------------------------------------------------------------- +# HEVD truncation +# --------------------------------------------------------------------------- + + +def _hermitian_with_spectrum(key, eigvals): + # `Q diag(eigvals) Q^T` with `Q` orthogonal: a (real-)Hermitian matrix with the + # given eigenvalues. Zeros placed between nonzero eigenvalues of both signs land + # in the interior of eigh's ascending order, exercising the fact that HEVD must + # truncate by *magnitude* (unlike SVD, where the small values are a contiguous + # tail). + size = len(eigvals) + q, _ = jnp.linalg.qr(jax.random.normal(key, (size, size))) + d = jnp.asarray(eigvals, dtype=q.dtype) + return (q * d[None, :]) @ q.T + + +def test_hevd_truncates_state_to_max_rank(): + # A genuinely rank-2 indefinite Hermitian matrix, declared rank 2: HEVD's state is + # truncated to 2 (eigenvalue, eigenvector) pairs and the solution is unchanged. + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), [3.0, -2.0, 0.0, 0.0, 0.0]) + solver = lx.HEVD() + + plain = lx.MatrixLinearOperator(matrix, lx.hermitian_tag) + (w_full, v_full), _ = solver.init(plain, {}) + assert w_full.shape == (5,) + assert v_full.shape == (5, 5) + + tagged = lx.MatrixLinearOperator(matrix, (lx.hermitian_tag, lx.MaxRankTag(2))) + (w_t, v_t), _ = solver.init(tagged, {}) + assert w_t.shape == (2,) + assert v_t.shape == (5, 2) + + vector = jnp.arange(5.0) + 0.5 + x_plain = lx.linear_solve(plain, vector, solver).value + x_tagged = lx.linear_solve(tagged, vector, solver).value + assert jnp.allclose(x_plain, x_tagged) + + +# PSD (eigenvalues >= 0) and NSD (<= 0) operators have their near-zero eigenvalues at +# a contiguous *end* of eigh's ascending order, so truncation is a slice rather than a +# reorder. Indefinite operators need the reordering gather. Cover all three branches. +@pytest.mark.parametrize( + "tag, eigvals", + ( + (lx.hermitian_tag, [3.0, -2.0, 0.0, 0.0, 0.0]), # indefinite -> reorder + (lx.positive_semidefinite_tag, [3.0, 2.0, 0.0, 0.0, 0.0]), # PSD -> slice tail + (lx.negative_semidefinite_tag, [-3.0, -2.0, 0.0, 0.0, 0.0]), # NSD -> head + ), +) +def test_hevd_truncation_branches(tag, eigvals): + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), eigvals) + solver = lx.HEVD() + plain = lx.MatrixLinearOperator(matrix, tag) + tagged = lx.MatrixLinearOperator(matrix, (tag, lx.MaxRankTag(2))) + + (w_t, v_t), _ = solver.init(tagged, {}) + assert w_t.shape == (2,) + assert v_t.shape == (5, 2) + + vector = jnp.arange(5.0) + 0.5 + x_plain = lx.linear_solve(plain, vector, solver).value + x_tagged = lx.linear_solve(tagged, vector, solver).value + assert jnp.allclose(x_plain, x_tagged) + + +@pytest.mark.parametrize( + "tag, eigvals", + ( + (lx.hermitian_tag, [3.0, -2.0, 1.5, 0.0, 0.0]), + (lx.positive_semidefinite_tag, [3.0, 2.0, 1.5, 0.0, 0.0]), + (lx.negative_semidefinite_tag, [-3.0, -2.0, -1.5, 0.0, 0.0]), + ), +) +def test_hevd_raises_when_max_rank_too_small(tag, eigvals): + # A genuinely rank-3 Hermitian matrix declared rank 2: truncation would discard an + # eigenvalue above the rcond threshold, so the solve must raise. + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), eigvals) + operator = lx.MatrixLinearOperator(matrix, (tag, lx.MaxRankTag(2))) + vector = jnp.arange(5.0) + 0.5 + with pytest.raises(eqx.EquinoxRuntimeError): + lx.linear_solve(operator, vector, lx.HEVD()) diff --git a/tests/test_operator.py b/tests/test_operator.py index d3c5eff..a4dc81d 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -249,6 +249,55 @@ def test_is_symmetric(dtype, getkey): _assert_except_diag(lx.is_symmetric, not_symmetric_operators, flip_cond=True) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_is_hermitian(dtype, getkey): + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + hermitian_operators = _setup(getkey, matrix + matrix.conj().T, lx.hermitian_tag) + for operator in hermitian_operators: + assert lx.is_hermitian(operator) + + not_hermitian_operators = _setup(getkey, matrix) + _assert_except_diag(lx.is_hermitian, not_hermitian_operators, flip_cond=True) + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_is_hermitian_implications(dtype, getkey): + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + real = jnp.issubdtype(dtype, jnp.floating) + + # PSD/NSD are Hermitian for both real and complex dtypes. + psd = matrix @ matrix.conj().T + assert lx.is_hermitian(lx.MatrixLinearOperator(psd, lx.positive_semidefinite_tag)) + assert lx.is_hermitian(lx.MatrixLinearOperator(-psd, lx.negative_semidefinite_tag)) + + # Symmetric (A = Aᵀ) and diagonal operators are Hermitian iff real-valued. + sym = lx.MatrixLinearOperator(matrix + matrix.T, lx.symmetric_tag) + assert lx.is_hermitian(sym) == real + assert lx.is_hermitian(lx.DiagonalLinearOperator(jnp.diag(matrix))) == real + + # Conversely a Hermitian operator is symmetric iff real-valued. + herm = lx.MatrixLinearOperator(matrix + matrix.conj().T, lx.hermitian_tag) + assert lx.is_hermitian(herm) + assert lx.is_symmetric(herm) == real + + +def test_hermitian_tag_propagation(getkey): + # Hermitian-ness is preserved through transpose and inversion, addition, and real + # (but not complex) scaling. + assert lx.hermitian_tag in lx.transpose_tags(frozenset({lx.hermitian_tag})) + assert lx.hermitian_tag in lx.invert_tags(frozenset({lx.hermitian_tag})) + + def herm_op(): + m = jr.normal(getkey(), (3, 3), dtype=jnp.complex128) + return lx.MatrixLinearOperator(m + m.conj().T, lx.hermitian_tag) + + op = herm_op() + assert lx.is_hermitian(op + herm_op()) # sum of Hermitian is Hermitian + assert lx.is_hermitian(-op) # negation preserves + assert lx.is_hermitian(op * 2.0) # real scaling preserves + assert not lx.is_hermitian(op * (1.0 + 1j)) # complex scaling does not + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 51b5821..c858762 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -60,3 +60,38 @@ def test_pytree_transpose(_, assert_transpose_fixture): # pyright: ignore in_vec = [a(1.0), 2.0, 3.0] solver = lx.AutoLinearSolver(well_posed=False) assert_transpose_fixture(operator, out_vec, in_vec, solver) + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_H_property(getkey, dtype): + # `.H` is the conjugate transpose `conj(A)ᵀ`, but a no-op for Hermitian operators. + m = jr.normal(getkey(), (4, 3), dtype=dtype) + op = lx.MatrixLinearOperator(m) + assert op.H is not op + assert tree_allclose(op.H.as_matrix(), m.conj().T) + + herm = jr.normal(getkey(), (3, 3), dtype=dtype) + herm = herm + herm.conj().T + hop = lx.MatrixLinearOperator(herm, lx.hermitian_tag) + assert hop.H is hop + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +@pytest.mark.parametrize("solver_cls", (lx.HEVD, lx.Cholesky, lx.LU)) +def test_hermitian_adjoint_state_reuses_init(getkey, dtype, solver_cls): + # For a Hermitian operator (`Aᴴ = A`), `init(A.H) == init(A)`, so the linear-solve + # JVP reuses the original state as the adjoint state -- this holds for any solver. + m = jr.normal(getkey(), (4, 4), dtype=dtype) + m = m + m.conj().T + if solver_cls is lx.Cholesky: + m = m @ m.conj().T # PSD, still Hermitian + op = lx.MatrixLinearOperator(m, lx.positive_semidefinite_tag) + else: + op = lx.MatrixLinearOperator(m, lx.hermitian_tag) + solver = solver_cls() + assert op.H is op + state = solver.init(op, {}) + adjoint_state = solver.init(op.H, {}) + assert tree_allclose( + eqx.filter(state, eqx.is_array), eqx.filter(adjoint_state, eqx.is_array) + ) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index 4686ce8..5a5e700 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -57,7 +57,7 @@ def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): def test_pytree_wellposed(solver, getkey, dtype): if not isinstance( solver, - (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG), + (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.HEVD, lx.CG), ): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10