Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,8 @@ Note that these do *not* inspect the values of the operator -- instead, they use

---

::: lineax.is_hermitian

---

::: lineax.max_rank
1 change: 1 addition & 0 deletions docs/api/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such tha
- mv
- as_matrix
- transpose
- H
- in_structure
- out_structure
- in_size
Expand Down
7 changes: 7 additions & 0 deletions docs/api/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ These are capable of solving ill-posed linear problems.

---

::: lineax.HEVD
options:
members:
- __init__

---

::: lineax.Normal
options:
members:
Expand Down
8 changes: 7 additions & 1 deletion docs/api/tags.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True

::: lineax.symmetric_tag

Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.)
Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) For real operators this coincides with `lineax.hermitian_tag`.

---

::: lineax.hermitian_tag

Marks that an operator is Hermitian (self-adjoint). (As a matrix, $A = A^*$, the conjugate transpose.) For real operators this coincides with `lineax.symmetric_tag`.

---

Expand Down
5 changes: 4 additions & 1 deletion lineax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
has_unit_diagonal as has_unit_diagonal,
IdentityLinearOperator as IdentityLinearOperator,
is_diagonal as is_diagonal,
is_hermitian as is_hermitian,
is_lower_triangular as is_lower_triangular,
is_negative_semidefinite as is_negative_semidefinite,
is_positive_semidefinite as is_positive_semidefinite,
Expand All @@ -49,16 +50,17 @@
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
AbstractLinearSolver as AbstractLinearSolver,
AutoLinearSolver as AutoLinearSolver,
invert as invert,
linear_solve as linear_solve,
)
from ._solver import (
AutoLinearSolver as AutoLinearSolver,
BiCGStab as BiCGStab,
CG as CG,
Cholesky as Cholesky,
Diagonal as Diagonal,
GMRES as GMRES,
HEVD as HEVD,
LSMR as LSMR,
LU as LU,
Normal as Normal,
Expand All @@ -70,6 +72,7 @@
)
from ._tags import (
diagonal_tag as diagonal_tag,
hermitian_tag as hermitian_tag,
invert_tags as invert_tags,
invert_tags_rules as invert_tags_rules,
lower_triangular_tag as lower_triangular_tag,
Expand Down
1 change: 1 addition & 0 deletions lineax/_operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
diagonal as diagonal,
has_unit_diagonal as has_unit_diagonal,
is_diagonal as is_diagonal,
is_hermitian as is_hermitian,
is_lower_triangular as is_lower_triangular,
is_negative_semidefinite as is_negative_semidefinite,
is_positive_semidefinite as is_positive_semidefinite,
Expand Down
54 changes: 54 additions & 0 deletions lineax/_operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class AbstractLinearOperator(eqx.Module):
def __check_init__(self):
if (
is_symmetric(self)
or is_hermitian(self)
or is_positive_semidefinite(self)
or is_negative_semidefinite(self)
):
Expand Down Expand Up @@ -197,6 +198,19 @@ def T(self) -> "AbstractLinearOperator":
"""Equivalent to [`lineax.AbstractLinearOperator.transpose`][]"""
return self.transpose()

@property
def H(self) -> "AbstractLinearOperator":
"""The conjugate transpose (Hermitian adjoint) `Aᴴ` of this operator.

Equivalent to `lineax.conj(operator).transpose()`, except that for a Hermitian
operator -- for which `Aᴴ = A` -- this is a no-op and returns `self` unchanged.
(As with [`lineax.is_hermitian`][], only the tag is checked, not the actual
values of the operator.)
"""
if is_hermitian(self):
return self
return conj(self).transpose()

def __add__(self, other) -> "AbstractLinearOperator":
# Local imports to avoid a circular dependency: `binary`/`wrapper` import
# `base`, so the operators built here are imported lazily at call time.
Expand Down Expand Up @@ -399,6 +413,20 @@ def tridiagonal(
_default_not_implemented("tridiagonal", operator)


def has_real_dtype(operator) -> bool:
"""Check if all dtypes in an operator's structure are real (not complex)."""
leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure()))
dtype = jnp.result_type(*leaves)
if jnp.issubdtype(dtype, jnp.complexfloating):
return False
elif jnp.issubdtype(dtype, jnp.floating):
return True
else:
assert False, (
"Only `jnp.floating` and `jnp.complexfloating` dtypes are understood."
)


@ft.singledispatch
def is_symmetric(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as symmetric.
Expand All @@ -417,6 +445,32 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool:
_default_not_implemented("is_symmetric", operator)


@ft.singledispatch
def is_hermitian(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as Hermitian (self-adjoint).

See [the documentation on linear operator tags](../api/tags.md) for more
information.

**Arguments:**

- `operator`: a linear operator.

**Returns:**

Either `True` or `False.`
"""
# Default for operators that don't register `is_hermitian` explicitly (e.g. custom
# `AbstractLinearOperator`s written before it existed): derive it from the other
# property checks, the same way the built-in operators do minus the
# `hermitian_tag` check, which needs operator-specific `.tags`.
if is_positive_semidefinite(operator) or is_negative_semidefinite(operator):
return True
if has_real_dtype(operator) and is_symmetric(operator):
return True
return False


@ft.singledispatch
def is_diagonal(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as diagonal.
Expand Down
14 changes: 14 additions & 0 deletions lineax/_operator/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
AbstractLinearOperator,
conj,
diagonal,
has_real_dtype,
has_unit_diagonal,
is_diagonal,
is_hermitian,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
Expand Down Expand Up @@ -203,6 +205,7 @@ def _(operator):

for check in (
is_symmetric,
is_hermitian,
is_diagonal,
is_lower_triangular,
is_upper_triangular,
Expand Down Expand Up @@ -247,6 +250,17 @@ def _(operator):
return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)


# is_hermitian: as above, diagonal matrices commute. A product of diagonals is itself
# diagonal, which is Hermitian only when its (complex) entries are real-valued.
@is_hermitian.register(ComposedLinearOperator)
def _(operator):
return (
is_diagonal(operator.operator1)
and is_diagonal(operator.operator2)
and has_real_dtype(operator)
)


# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
@is_tridiagonal.register(ComposedLinearOperator)
Expand Down
38 changes: 22 additions & 16 deletions lineax/_operator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from .._tags import (
diagonal_tag,
hermitian_tag,
lower_triangular_tag,
negative_semidefinite_tag,
positive_semidefinite_tag,
Expand All @@ -53,9 +54,11 @@
conj,
diagonal,
FlatPyTree,
has_real_dtype,
has_unit_diagonal,
inexact_structure,
is_diagonal,
is_hermitian,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
Expand Down Expand Up @@ -724,20 +727,6 @@ def _(operator):
# checks


def _has_real_dtype(operator) -> bool:
"""Check if all dtypes in an operator's structure are real (not complex)."""
leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure()))
dtype = jnp.result_type(*leaves)
if jnp.issubdtype(dtype, jnp.complexfloating):
return False
elif jnp.issubdtype(dtype, jnp.floating):
return True
else:
assert False, (
"Only `jnp.floating` and `jnp.complexfloating` dtypes are understood."
)


@is_symmetric.register(MatrixLinearOperator)
@is_symmetric.register(PyTreeLinearOperator)
@is_symmetric.register(JacobianLinearOperator)
Expand All @@ -746,12 +735,29 @@ def _(operator):
# Symmetric (A = A^T) if explicitly tagged symmetric or diagonal
if symmetric_tag in operator.tags or diagonal_tag in operator.tags:
return True
# PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian
# PSD/NSD/Hermitian imply A = A^T only for real dtypes
if (
positive_semidefinite_tag in operator.tags
or negative_semidefinite_tag in operator.tags
or hermitian_tag in operator.tags
):
return has_real_dtype(operator)
return False


@is_hermitian.register(MatrixLinearOperator)
@is_hermitian.register(PyTreeLinearOperator)
@is_hermitian.register(JacobianLinearOperator)
@is_hermitian.register(FunctionLinearOperator)
def _(operator):
if (
hermitian_tag in operator.tags
or positive_semidefinite_tag in operator.tags
or negative_semidefinite_tag in operator.tags
):
return _has_real_dtype(operator)
return True
if symmetric_tag in operator.tags or diagonal_tag in operator.tags:
return has_real_dtype(operator)
return False


Expand Down
9 changes: 9 additions & 0 deletions lineax/_operator/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
conj,
diagonal,
FlatPyTree,
has_real_dtype,
has_unit_diagonal,
inexact_structure,
is_diagonal,
is_hermitian,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
Expand Down Expand Up @@ -292,6 +294,7 @@ def _(operator):


@is_symmetric.register(IdentityLinearOperator)
@is_hermitian.register(IdentityLinearOperator)
def _(operator):
return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True

Expand All @@ -301,7 +304,13 @@ def _(operator):
return True


@is_hermitian.register(DiagonalLinearOperator)
def _(operator):
return has_real_dtype(operator)


@is_symmetric.register(TridiagonalLinearOperator)
@is_hermitian.register(TridiagonalLinearOperator)
def _(operator):
return False

Expand Down
46 changes: 46 additions & 0 deletions lineax/_operator/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from .._tags import (
diagonal_tag,
hermitian_tag,
lower_triangular_tag,
MaxRankTag,
negative_semidefinite_tag,
Expand All @@ -43,8 +44,10 @@
as_frozenset,
conj,
diagonal,
has_real_dtype,
has_unit_diagonal,
is_diagonal,
is_hermitian,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
Expand Down Expand Up @@ -340,6 +343,7 @@ def _(operator):

for check in (
is_symmetric,
is_hermitian,
is_diagonal,
has_unit_diagonal,
is_lower_triangular,
Expand Down Expand Up @@ -371,6 +375,28 @@ def _(operator, check=check):
return check(operator.operator)


def _scalar_is_real(scalar) -> bool:
"""Whether a scalar is statically known to be real-valued.

A real dtype guarantees a real value, so this is known even for JAX tracers (whose
runtime value is unknown at trace time): only the dtype matters, not the value.
Returns `False` only for genuinely complex-typed scalars.
"""
return not jnp.issubdtype(jnp.result_type(scalar), jnp.complexfloating)


# Hermitian-ness preserved by negation and scaling by any real scalar
@is_hermitian.register(NegLinearOperator)
def _(operator):
return is_hermitian(operator.operator)


@is_hermitian.register(MulLinearOperator)
@is_hermitian.register(DivLinearOperator)
def _(operator):
return _scalar_is_real(operator.scalar) and is_hermitian(operator.operator)


# has_unit_diagonal is NOT preserved by negation
@has_unit_diagonal.register(NegLinearOperator)
def _(operator):
Expand Down Expand Up @@ -498,6 +524,26 @@ def _(operator, check=check, tag=tag):
return (tag in operator.tags) or check(operator.operator)


# `is_hermitian` is special-cased rather than handled by the loop above: a tag other
# than `hermitian_tag` can still imply Hermitian-ness. PSD/NSD operators are Hermitian
# (real or complex), and real symmetric/diagonal operators are Hermitian too. This
# mirrors the cross-implications encoded for the core operators.
@is_hermitian.register(TaggedLinearOperator)
def _(operator):
tags = operator.tags
if is_hermitian(operator.operator):
return True
if (
hermitian_tag in tags
or positive_semidefinite_tag in tags
or negative_semidefinite_tag in tags
):
return True
if symmetric_tag in tags or diagonal_tag in tags:
return has_real_dtype(operator)
return False


@max_rank.register(TaggedLinearOperator)
def _(operator):
inner = max_rank(operator.operator)
Expand Down
Loading
Loading