Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions lineax/_solver/bicgstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def compute(
and self.rtol == 0
)
if has_scale:
b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω
b_scale = self.atol + self.rtol * self.norm(vector)

# This implementation is the same a jax.scipy.sparse.linalg.bicgstab
# but with AbstractLinearOperator.
Expand All @@ -117,11 +117,16 @@ def not_converged(r, diff, y):
# Given Ay=b, then we have to be doing better than `scale` in both
# the `y` and the `b` spaces.
if has_scale:
with jax.numpy_dtype_promotion("standard"):
y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore
norm2 = self.norm((diff**ω / y_scale**ω).ω)
return (norm1 > 1) | (norm2 > 1)
# Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖
# (and likewise for the increment in the `y` space). Note this uses
# scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the
# latter is unsatisfiable for wide-dynamic-range `b`, where the
# round-off floor of large components exceeds the absolute tolerance
# demanded of small ones, yielding spurious non-convergence.
y_scale = self.atol + self.rtol * self.norm(y)
b_unconverged = self.norm(r) > b_scale # pyright: ignore
y_unconverged = self.norm(diff) > y_scale
return b_unconverged | y_unconverged
else:
return True

Expand Down
18 changes: 11 additions & 7 deletions lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any, TypeAlias

import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down Expand Up @@ -144,18 +143,23 @@ def compute(
and self.rtol == 0
)
if has_scale:
b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω
b_scale = self.atol + self.rtol * self.norm(vector)

def not_converged(r, diff, y):
# The primary tolerance check.
# Given Ay=b, then we have to be doing better than `scale` in both
# the `y` and the `b` spaces.
if has_scale:
with jax.numpy_dtype_promotion("standard"):
y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore
norm2 = self.norm((diff**ω / y_scale**ω).ω)
return (norm1 > 1) | (norm2 > 1)
# Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖
# (and likewise for the increment in the `y` space). Note this uses
# scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the
# latter is unsatisfiable for wide-dynamic-range `b`, where the
# round-off floor of large components exceeds the absolute tolerance
# demanded of small ones, yielding spurious non-convergence.
y_scale = self.atol + self.rtol * self.norm(y)
b_unconverged = self.norm(r) > b_scale # pyright: ignore
y_unconverged = self.norm(diff) > y_scale
return b_unconverged | y_unconverged
else:
return True

Expand Down
17 changes: 11 additions & 6 deletions lineax/_solver/gmres.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def compute(
and self.rtol == 0
)
if has_scale:
b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω
b_scale = self.atol + self.rtol * self.norm(vector)
operator = state
preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
leaves, _ = jtu.tree_flatten(vector)
Expand All @@ -132,11 +132,16 @@ def not_converged(r, diff, y):
# Given Ay=b, then we have to be doing better than `scale` in both
# the `y` and the `b` spaces.
if has_scale:
with jax.numpy_dtype_promotion("standard"):
y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore
norm2 = self.norm((diff**ω / y_scale**ω).ω)
return (norm1 > 1) | (norm2 > 1)
# Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖
# (and likewise for the increment in the `y` space). Note this uses
# scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the
# latter is unsatisfiable for wide-dynamic-range `b`, where the
# round-off floor of large components exceeds the absolute tolerance
# demanded of small ones, yielding spurious non-convergence.
y_scale = self.atol + self.rtol * self.norm(y)
b_unconverged = self.norm(r) > b_scale # pyright: ignore
y_unconverged = self.norm(diff) > y_scale
return b_unconverged | y_unconverged
else:
return True

Expand Down
Loading