Skip to content

fix iterative breakdowns due to element-wise tolerances#230

Open
jpbrodrick89 wants to merge 1 commit into
devfrom
jpb/scalar-convergence-tol
Open

fix iterative breakdowns due to element-wise tolerances#230
jpbrodrick89 wants to merge 1 commit into
devfrom
jpb/scalar-convergence-tol

Conversation

@jpbrodrick89

@jpbrodrick89 jpbrodrick89 commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

A flaky GMRES test on #228 led me down a rabbit hole to realise that all iterative solvers except for LSMR have been using nonlinear solver style element-wise convergence checks (e.g. Hairer–Nørsett–Wanner) to handle elements with hugely differing scale, i.e.:

b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω
y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
norm1 = self.norm((r**ω / b_scale**ω).ω)
norm2 = self.norm((diff**ω / y_scale**ω).ω)
return (norm1 > 1) | (norm2 > 1)

However, for iterative LINEAR solvers RHS elements of different scales will inevitably mix unless the matrix is sufficiently sparse leading to round-off error that to "iterative breakdowns". These cases can crop up randomly in the tests.

MWE:

A = jnp.array([[1.0, 0.5, 0.3],
               [0.2, 2.0, 0.7],
               [0.6, 0.1, 1.5]])
b = jnp.array([1e8, 1.0, 1.0])

lx.linear_solve(lx.MatrixLinearOperator(A), b, lx.GMRES(rtol=1e-12, atol=1e-12))
# -> equinox.EquinoxRuntimeError: A form of iterative breakdown has occurred ...

# ...even though the system is solved essentially perfectly:

sol = lx.linear_solve(lx.MatrixLinearOperator(A), b,
                      lx.GMRES(rtol=1e-12, atol=1e-12), throw=False)
print(jnp.linalg.norm(A @ sol.value - b) / jnp.linalg.norm(b))   # 1.7e-16
print(jnp.allclose(sol.value, jnp.linalg.solve(A, b)))           # True

Note that the same solve passes at rtol=atol=1e-6 — the bug only bites when the requested tolerance dips near the achievable round-off floor.

This is fixed by adopting the standard convention (e.g. as in scipy) to only compare the norm of the vector to the tolerance:

b_scale = self.atol + self.rtol * self.norm(vector)
y_scale = self.atol + self.rtol * self.norm(y)
return (self.norm(r) > b_scale) | (self.norm(diff) > y_scale)

@patrick-kidger this has been hugely long-standing (introduced in "initial commit") and I understand you might have some affinity towards the original approach. As such, I will await your oversight before merging this.

@patrick-kidger

Copy link
Copy Markdown
Owner

So this is actually an explicit choice, and we have some discussion on this in Diffrax here. From the discussion there, I'd actually have thought that the current formulation should be more stable when considering components with different scales?

Basically, if we ever do norm(a) / norm(b) then we'll end up with a weaker convergence condition – we just wash away all the small scales and only really care if the big ones have converged. That doesn't seem particularly fair to the small scales.

It sounds like your MWE represents a case where the small scales are now being overrepresented instead? Perhaps due to numerical stability/precision issues?

(Side note, one of the things we did try to do across Lineax/Optimistix/Diffrax was standardise convergence criteria on the one we have here, as broadly speaking the rest of the numerical literature just does in some ad-hoc problem-specific way every time.)

@jpbrodrick89

Copy link
Copy Markdown
Collaborator Author

Perhaps due to numerical stability/precision issues?

Yes, it is entirely because of numerical precision/round-off error, I believe GMRES is guaranteed to converge in n iterations with a tolerance of 0 if we have infinite precision.

The point is that is A.mv couples components of different scales. This sets a tolerance floor of order eps * cond(A) * max(b) (this is very approximate, the true floor is probably better expressed the way LSMR handles it).

I would argue that iterative linear solvers only ever make theoretical guarantees about convergence in the Euclidean norm (this is the "minimum residual" referred to in GMRES) and element-wise covergence is an unrealistic expectation. On the other hand failure of an ODE nonlinear solve allows a resolution through adaptive stepping (it is essentially a signal that a problem is too "difficult" to solve in some nonlinear sense and dropping the timestep will make it more linear, if a problem is already linear then no scaling would make an iterative solver behave any better).

That said if this is too much of a paradigm shift we can just fix the flaky tests by patching this into atol (i.e. atol=epsmax_condmax(b)~1e-13*max(b) which turns to be 1e-8/1e-9 when using calculating a jvp (as max(t_matrix@b) scales as cond(t) max(b)) the current tol is probably fine for the primal solve. So in summary, reducing atol by ~1e3-4 for test_jvp and another factor again for test_jvp_jvp wil probably cut it but in my opinion its a bit awkward.

@patrick-kidger

Copy link
Copy Markdown
Owner

I would argue that iterative linear solvers only ever make theoretical guarantees about convergence in the Euclidean norm (this is the "minimum residual" referred to in GMRES)

I think that's a reasonable position to take. In theory convergence in one norm guarantess convergence in all others, but as a practical matter that's not really how things work :D

I also like your observation that generally speaking failures become a matter of "can we make this problem more linear".


One other guiding principle that went into this, btw, is to ask what would happen if we pad our problems with extra zeros. Ideally the result would be identical. (And this principle is also the reason that Lineax unifies the finding of inverses and pseudoinverses, e.g. consider adding extra zeros to both operator and vector in a well-posed problem: the resulting pseudoinverse solution is just the similarly-padded inverse solution of the original problem.)

It's for this reason that we use norm=max_norm pretty much ubiquitously, and not two_norm.
(Although checking this over I've just spotted that LSMR uses two_norm, that snuck by me... so whatever we decide here we should make sure it's consistent over there too.)

Do you think there is any way we could keep your "make this problem more linear' intuition, whilst preserving this kind of zero-padding property?

@jpbrodrick89

Copy link
Copy Markdown
Collaborator Author

I'm not sure I follow, I think the result remains identical under zero padding. I'm not proposing to change norm but I think both max_norm and two_norm have this property (this is because both norms themselves are invariant under zero padding). The thing that this PR DOES affect is behaviour under block diagonal padding. After this PR we would likely terminate sooner before one of the block diagonals has converged because we don't know a priori there is no mixing that would imply we COULDN'T converge numerically.

@jpbrodrick89

Copy link
Copy Markdown
Collaborator Author

Following up from our conversation on Monday, I don't there exists a general norm that would give us the same results as before that we could offer as an alternative. If we want to allow users to recover the original we could offer an elementwise_rtol=None option (we assume they would want to use the same atol elementwise and globally even though that would mean atol is more lax in the elementwise setting than the norm-based one). To reiterate, I think its safe to assume that repeated mv's will essentially mix all elements of the vector rather than try come up with janky tag-based exceptions to this until we eventually identify an elegant sparse/block operator design.

@jpbrodrick89

Copy link
Copy Markdown
Collaborator Author

I think I've worked out the "elementwise" norm we desired that would recover the current behaviour is just a weighted norm with weights that depend on y/vector. So one solution is to allow/enforce a second argument on our norm functions, e.g.

def elementwise_norm(x, args):
    atol, rtol, ref = args
    return max_norm((x**ω / (atol + rtol * ω(ref).call(jnp.abs)).ω)

We would need to refactor all our existing norms to take these dummy args.

Alternatively one could hardcode some/all of the args but that's probably less flexible.

WDYT? Default to standard max_norm as this PR suggests and add the dummy args and elementwise_norm functionality to allow recovery of previous behaviour?

@jpbrodrick89 jpbrodrick89 added the bug Something isn't working label Jun 23, 2026
@patrick-kidger

Copy link
Copy Markdown
Owner

Ah, interesting – you're saying that you've found an approach that just adjusts the choice of reference used with the rtol?
IIUC that wouldn't actually mean changing the norms, it would just mean adjusting a couple of lines like this one:

b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω

so that we use something other than vector there.

What would you change it to?


One point that's not yet clear to me: we speculated that the current approach goes wrong solely due to floating-point issues (since in the infinite-precision limit one expects convergence in a finite number of iterations). Were you able to identify whether that is indeed the case?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants