
Add pivoted qr solver #210

Open
adconner wants to merge 2 commits into patrick-kidger:main from adconner:push-slkzlznovwro

Conversation

@adconner

adconner commented Mar 4, 2026

Contributor

Let me know your opinions about the design. The most opinionated choice is the static vs. dynamic rank modes, similar to @jpbrodrick89's suggestion in #201 (comment). In my design, the user asserts static rank equality rather than inequality, which I think might be useful in the contexts where it arises (i.e., if we are optimizing in the locus of rank r matrices and get close to the locus of rank r-1, we might not want the structure of the calculation to change suddenly).
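To make the "assert rank equality" idea concrete, here is a minimal NumPy/SciPy sketch. It is not this PR's API: `estimate_rank` and `assert_rank_equals` are hypothetical helper names, and the relative tolerance `rtol` is an assumed parameter. The sketch estimates the numerical rank from the diagonal of the pivoted-QR factor R and raises if it differs from the rank the caller asserted.

```python
import numpy as np
from scipy.linalg import qr


def estimate_rank(a, rtol=1e-10):
    """Estimate the numerical rank of `a` from a pivoted QR (hypothetical helper)."""
    # With column pivoting, |diag(R)| is non-increasing, so small trailing
    # entries signal numerical rank deficiency.
    _, r, _ = qr(a, mode="economic", pivoting=True)
    diag = np.abs(np.diag(r))
    if diag.size == 0 or diag[0] == 0.0:
        return 0
    # Count diagonal entries that are large relative to the leading one.
    return int(np.sum(diag > rtol * diag[0]))


def assert_rank_equals(a, expected_rank, rtol=1e-10):
    """Raise if the estimated rank differs from the asserted one (hypothetical helper)."""
    rank = estimate_rank(a, rtol)
    if rank != expected_rank:
        raise ValueError(f"expected rank {expected_rank}, found {rank}")
    return rank


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 2)) @ rng.normal(size=(2, 4))  # rank 2 by construction
assert_rank_equals(a, 2)
```

Under an equality assertion, later slices such as `R[:rank]` would have a fixed size known ahead of time, whereas a rank that is only bounded from above would presumably need masking or padding in a traced JAX computation.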

Testing this correctly required some test restructuring (as well as bug fixes).

There is a jax bug in the jvp rule for pivoted qr in the presence of vmapping, so there are test failures, which are resolved by jax-ml/jax#35586. The fact that the tests fail because of a bug in jax's qr jvp is a bit mysterious to me, because my understanding is that the lineax solver jvp should never call the jax qr jvp. I haven't yet investigated further, but I welcome insight into this.

Resolves #141 and #201.

@tttc3

tttc3 commented Mar 4, 2026


I've updated the title on #209 to help clarify how our approaches differ. I'd be really keen to get your thoughts!

adconner force-pushed the push-slkzlznovwro branch from a1c854d to 2fc18dc March 9, 2026 16:28
@adconner

adconner commented Mar 9, 2026

Contributor Author

I stacked this change atop #212, which fixes the test failures caused by jax's broken qr jvp.

@jpbrodrick89

jpbrodrick89 commented Jun 20, 2026

Collaborator

Hi @adconner @tttc3, sorry for taking so long to comment on this. Thank you both so much for the contributions! I've been deliberating on the best design for a while now. This is where my current thoughts sit:

  1. I prefer the secondary QR over Cholesky, as it should be more numerically stable and better mirrors dgelsy.
  2. We shouldn't wait for a jax wrapping of dtzrzf even though it's a LOT more efficient for nearly full rank systems. Beyond potential dynamic sizing restrictions, there isn't a good GPU implementation I'm aware of.
  3. I don't feel too strongly about whether we extend the existing QR class or create a new class. There is precedent in solvers conditionally assuming full rank (both DiagonalLinearSolver and AutoLinearSolver fall into this category), but the differences in state and degree of complexity MIGHT warrant separation (personally I lean towards extending QR, as there might be more shared logic with other functionality I'm working on, such as slogdet and variable projection). If we extend QR, my preferred opt-in would be pivoted=True to mirror scipy or well_posed=False to mirror other solvers (I personally prefer the former), and instead of supporting a union of states we would make the state a fixed struct, with jpvt and rz being None if pivoting is False. If we create a new class, I'd lean towards calling it PivotedQR or RRQR.
  4. I'd prefer all (possibly) rank-deficient solvers to have as consistent an API as possible unless there is something actually unique about the solver (see below). For example, while allowing users to opt into error raising if rank is less than expected might indeed be a desirable feature, I don't see why we would want QR to do that and not SVD or DiagonalLinearSolver. Right now the common API is based on the max rank tag, but if you think we should consider alternative/additional config, it should live in its own separate PR/issue either before or after (willing to be open minded about this).
  5. The one thing that IS unique about pivoted QR is that the amount of extra work required for the orthogonal decomposition is very high (O(max_rank(op)^3) in these implementations). So it is tempting to support a fast path for the full rank case, such as @tttc3's lax.cond. For example, if one is solving a nonlinear least squares problem with a system that is USUALLY full rank but can sometimes hit regions of rank deficiency, lax.cond's extra compile cost and lack of CSE might be worth the compute saving, as long as one is not vmapping (in which case both branches might be run). As such, it might be tempting to add a use_cond option. @patrick-kidger I will defer to you here, given your deep experience of how dangerous it might be to use lax.cond in a nonlinear solve loop, as well as how best to manage the potentially footgunnable API for this one. My lean is that as long as full orthogonal decomposition is faster than SVD, having no lax.cond is fine until someone complains about wanting to go even faster.
  6. Note I will shortly be putting in a PR for Hermitian eigenvalue decomposition, which will provide a faster solve for square Hermitian rank-deficient systems. We will have to carefully benchmark this against pivoted QR to decide the dispatch priority in AutoLinearSolver (and also be careful that pivoted QR may even be worse than SVD on GPU, especially if a user doesn't have MAGMA).
  7. Note we now use ormqr in QR, so we should do the same here.
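For readers following the "secondary QR" discussion in point 1, the following is a minimal NumPy/SciPy sketch of a minimum-norm least-squares solve built from a pivoted QR followed by a second QR of the leading rows of R. It is an illustration under stated assumptions, not the implementation in this PR or in lineax: the function name `min_norm_lstsq_pivoted_qr` is hypothetical, the rank is passed in by the caller, only real-valued inputs are handled, and Q is formed explicitly rather than applied via ormqr.

```python
import numpy as np
from scipy.linalg import qr, solve_triangular


def min_norm_lstsq_pivoted_qr(a, b, rank):
    """Minimum-norm least-squares solution of a @ x ~= b (hypothetical sketch).

    Assumes `a` is real with numerical rank `rank`.
    """
    # Pivoted QR: a[:, piv] == q @ r, with |diag(r)| non-increasing.
    q, r, piv = qr(a, mode="economic", pivoting=True)
    # Keep the leading `rank` rows; the discarded rows are numerically zero.
    t = r[:rank, :]                      # shape (rank, n), full row rank
    # Secondary QR of t.T: t.T == z @ s, so t == s.T @ z.T.
    z, s = qr(t.T, mode="economic")      # z: (n, rank), s: (rank, rank)
    # Project the right-hand side onto the range of a.
    c = q[:, :rank].T @ b
    # Minimum-norm solution of t @ y == c is y = z @ inv(s.T) @ c.
    w = solve_triangular(s, c, trans="T")  # solves s.T @ w == c
    y = z @ w
    # Undo the column permutation: y holds x in pivoted order.
    x = np.empty_like(y)
    x[piv] = y
    return x
```

The secondary QR costs a factorization of an n-by-rank matrix, which is the extra work that point 5 weighs against a full-rank fast path.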

Let me know if you have any other questions, very happy to discuss further. I think this will be a very valuable and well used feature once we've ironed out the complexities. 😀

@jpbrodrick89 jpbrodrick89 added the feature New feature label Jun 23, 2026

Labels

feature New feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Utilise pivoted QR functionality added in JAX v0.5.1

3 participants