feat: GPU overrides for GeneralUnitaryMatrices (exp!, log!, retract!, project!)#4
Merged
mateuszbaran merged 2 commits intoJuliaManifolds:mainfrom Mar 13, 2026
Conversation
… project!)
Add batched GPU-accelerated operations for PowerManifold{..., GeneralUnitaryMatrices}
covering Rotations, OrthogonalMatrices, and UnitaryMatrices.
New GPU operations:
- exp! via Scaling & Squaring matrix exponential (Real + Complex)
- log! via Inverse Scaling & Squaring with Denman-Beavers iteration
- retract! (PolarRetraction) via batched SVD
- project! for points (AbsoluteDeterminantOneMatrixType) and tangent vectors
Shared GPU primitives added to helpers.jl:
- _matrix_exp_gpu extended for Complex element types
- _matrix_log_gpu (Denman-Beavers square root + Taylor series)
- _batched_inv_gpu (LU via getrf/getri_strided_batched!)
- _batched_sqrtm_gpu (Denman-Beavers iteration)
Test suite restructured into test/jlarray/ and test/cuda/ directories.
Comprehensive tests for Float32/Float64, Real/Complex, all three manifold aliases.
Benchmark script added for Rotations and UnitaryMatrices.
Member
|
Nice, good work. The design here looks right here 👍 . |
Member
|
Nice work. I think in this small package we can also make the GPU packages hard dependencies and not use extensions? |
Member
|
Using extensions has the benefit of not having to load the GPU packages like CUDA on unsupported platforms. We only load what's needed. It will be relevant when we support multiple different backends (for example ROCm, Metal). |
Member
|
Ok. If there are new methods introduced they could still be documented in the main file then for visibility. |
mateuszbaran
approved these changes
Mar 13, 2026
Member
|
Yes, documentation is important. I will try organizing the docs later. For now, this LGTM 👍 |
38 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds batched GPU-accelerated operations for
PowerManifold{..., <:GeneralUnitaryMatrices}, coveringRotations,OrthogonalMatrices, andUnitaryMatrices. This follows up on the feedback from PR #2 — only manifolds that genuinely need batched CUBLAS/cuSOLVER calls getPowerManifoldoverrides.What's implemented
New GPU operations (
ext/ManifoldsGPUCUDAExt/GeneralUnitaryMatrices.jl)exp!gemm_strided_batchedlog!(Real)gemm_strided_batched,getrf/getri_strided_batched!log!(Complex)retract!(Polar)p + p*tX→ batched SVD →U*V'gemm_strided_batched,gesvdj!project!(point)AbsoluteDeterminantOneMatrixTypeonly)gesvdj!,gemm_strided_batchedproject!(tangent)(p'X - X'p)/2gemm_strided_batchedPowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation}— no field parameter constraint, soRotations,OrthogonalMatrices, andUnitaryMatricesall match.project!(point) is restricted toAbsoluteDeterminantOneMatrixType(notRotations, which needs determinant correction).Shared GPU primitives (
ext/ManifoldsGPUCUDAExt/helpers.jl)_matrix_exp_gpu(Complex)ComplexF32/ComplexF64_matrix_log_gpulog(I+X)_batched_inv_gpugetrf_strided_batched!+getri_strided_batched!)_batched_sqrtm_gpuAll helpers are general-purpose —
_matrix_exp_gpuis already shared withStiefel.jl.File structure
Tests
allowscalar(false)): 32 tests coveringexp!,retract!,project!for Rotations/UnitaryMatrices in Float32/Float64.log!is CUDA-only (requiresgetrf/getri).is_point) and CPU/GPU agreement.Benchmark results (RTX 3090, n=16, batch=2048, Float32)
log!speedup is lower due to the iterative Denman-Beavers square root (4×10 iterations with batched LU inversion each). Still a significant improvement over serial CPU eigendecomposition.Design notes
gesvdj!has a size limit in cuSOLVER — CPU fallback is provided for matrices exceeding it (wrapped intry/catch, intentionally outside AD paths).log!promotes to complex internally (unitary square roots may have complex entries), then takesreal().