From 22d4042efa2e114b56d7ce0ec20cd9802b777d77 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 5 Jun 2026 01:42:23 +0100 Subject: [PATCH] fix iterative breakdowns due to element-wise tolerances --- lineax/_solver/bicgstab.py | 17 +++++++++++------ lineax/_solver/cg.py | 18 +++++++++++------- lineax/_solver/gmres.py | 17 +++++++++++------ 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 31cf241d..26e9164c 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -93,7 +93,7 @@ def compute( and self.rtol == 0 ) if has_scale: - b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω + b_scale = self.atol + self.rtol * self.norm(vector) # This implementation is the same a jax.scipy.sparse.linalg.bicgstab # but with AbstractLinearOperator. @@ -117,11 +117,16 @@ def not_converged(r, diff, y): # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: - with jax.numpy_dtype_promotion("standard"): - y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω - norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore - norm2 = self.norm((diff**ω / y_scale**ω).ω) - return (norm1 > 1) | (norm2 > 1) + # Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖ + # (and likewise for the increment in the `y` space). Note this uses + # scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the + # latter is unsatisfiable for wide-dynamic-range `b`, where the + # round-off floor of large components exceeds the absolute tolerance + # demanded of small ones, yielding spurious non-convergence. + y_scale = self.atol + self.rtol * self.norm(y) + b_unconverged = self.norm(r) > b_scale # pyright: ignore + y_unconverged = self.norm(diff) > y_scale + return b_unconverged | y_unconverged else: return True diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index eee8c970..a70d0156 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -17,7 +17,6 @@ from typing import Any, TypeAlias import equinox.internal as eqxi -import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -144,18 +143,23 @@ def compute( and self.rtol == 0 ) if has_scale: - b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω + b_scale = self.atol + self.rtol * self.norm(vector) def not_converged(r, diff, y): # The primary tolerance check. # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: - with jax.numpy_dtype_promotion("standard"): - y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω - norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore - norm2 = self.norm((diff**ω / y_scale**ω).ω) - return (norm1 > 1) | (norm2 > 1) + # Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖ + # (and likewise for the increment in the `y` space). Note this uses + # scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the + # latter is unsatisfiable for wide-dynamic-range `b`, where the + # round-off floor of large components exceeds the absolute tolerance + # demanded of small ones, yielding spurious non-convergence. + y_scale = self.atol + self.rtol * self.norm(y) + b_unconverged = self.norm(r) > b_scale # pyright: ignore + y_unconverged = self.norm(diff) > y_scale + return b_unconverged | y_unconverged else: return True diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index d5911a06..0b719d59 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -116,7 +116,7 @@ def compute( and self.rtol == 0 ) if has_scale: - b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω + b_scale = self.atol + self.rtol * self.norm(vector) operator = state preconditioner, y0 = preconditioner_and_y0(operator, vector, options) leaves, _ = jtu.tree_flatten(vector) @@ -132,11 +132,16 @@ def not_converged(r, diff, y): # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: - with jax.numpy_dtype_promotion("standard"): - y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω - norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore - norm2 = self.norm((diff**ω / y_scale**ω).ω) - return (norm1 > 1) | (norm2 > 1) + # Standard relative-residual stopping rule: ‖r‖ ≤ atol + rtol·‖b‖ + # (and likewise for the increment in the `y` space). Note this uses + # scalar norms, *not* an elementwise `atol + rtol·|b|` scale: the + # latter is unsatisfiable for wide-dynamic-range `b`, where the + # round-off floor of large components exceeds the absolute tolerance + # demanded of small ones, yielding spurious non-convergence. + y_scale = self.atol + self.rtol * self.norm(y) + b_unconverged = self.norm(r) > b_scale # pyright: ignore + y_unconverged = self.norm(diff) > y_scale + return b_unconverged | y_unconverged else: return True