
feat: GPU overrides for GeneralUnitaryMatrices (exp!, log!, retract!, project!)#4

Merged
mateuszbaran merged 2 commits into JuliaManifolds:main from zazabap:feat/general-unitary-gpu on Mar 13, 2026
Conversation

@zazabap
Contributor

@zazabap zazabap commented Mar 11, 2026

Summary

Adds batched GPU-accelerated operations for PowerManifold{..., <:GeneralUnitaryMatrices}, covering Rotations, OrthogonalMatrices, and UnitaryMatrices. This follows up on the feedback from PR #2 — only manifolds that genuinely need batched CUBLAS/cuSOLVER calls get PowerManifold overrides.

What's implemented

New GPU operations (ext/ManifoldsGPUCUDAExt/GeneralUnitaryMatrices.jl)

| Function | Strategy | CUDA primitives |
| --- | --- | --- |
| `exp!` | Scaling & squaring matrix exponential | `gemm_strided_batched` |
| `log!` (Real) | Inverse scaling & squaring (Denman-Beavers sqrtm + Taylor series) | `gemm_strided_batched`, `getrf`/`getri_strided_batched!` |
| `log!` (Complex) | Same, with conjugate-transpose skew-Hermitian projection | Same |
| `retract!` (Polar) | `p + p*tX` → batched SVD → `U*V'` | `gemm_strided_batched`, `gesvdj!` |
| `project!` (point) | Batched SVD polar factor (`AbsoluteDeterminantOneMatrixType` only) | `gesvdj!`, `gemm_strided_batched` |
| `project!` (tangent) | Skew-symmetric projection `(p'X - X'p)/2` | `gemm_strided_batched` |
  • Dispatches on PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation} — no field parameter constraint, so Rotations, OrthogonalMatrices, and UnitaryMatrices all match.
  • project! (point) is restricted to AbsoluteDeterminantOneMatrixType (not Rotations, which needs determinant correction).
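For intuition, here is a minimal, language-agnostic sketch of the scaling & squaring strategy behind the batched `exp!`. This is plain Python on a single 2x2 matrix and is not the PR's code; the scaling depth and term count are illustrative. On the GPU, every matmul below becomes one `gemm_strided_batched` over the whole batch.

```python
import math

def matmul(A, B):
    # Naive 2x2 matrix product; the GPU version is a batched GEMM.
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def expm_2x2(A, s=6, terms=12):
    # Scale: exp(A) = exp(A / 2^s)^(2^s); A / 2^s has small norm,
    # so a short Taylor series is accurate.
    B = [[A[i][j] / 2**s for j in range(2)] for i in range(2)]
    X = [[1.0, 0.0], [0.0, 1.0]]      # accumulates the Taylor sum
    term = [[1.0, 0.0], [0.0, 1.0]]   # holds B^k / k!
    for k in range(1, terms):
        term = matmul(term, [[B[i][j] / k for j in range(2)] for i in range(2)])
        X = [[X[i][j] + term[i][j] for j in range(2)] for i in range(2)]
    # Square s times to undo the scaling.
    for _ in range(s):
        X = matmul(X, X)
    return X

# exp of a skew-symmetric generator is a rotation by theta.
theta = 0.3
R = expm_2x2([[0.0, -theta], [theta, 0.0]])
# R is approximately [[cos(0.3), -sin(0.3)], [sin(0.3), cos(0.3)]]
```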

Shared GPU primitives (ext/ManifoldsGPUCUDAExt/helpers.jl)

| Function | Purpose |
| --- | --- |
| `_matrix_exp_gpu` (Complex) | Extends existing Real implementation to `ComplexF32`/`ComplexF64` |
| `_matrix_log_gpu` | Inverse scaling & squaring: 4 repeated Denman-Beavers square roots + 16-term Taylor `log(I+X)` |
| `_batched_inv_gpu` | Batched matrix inverse via LU (`getrf_strided_batched!` + `getri_strided_batched!`) |
| `_batched_sqrtm_gpu` | Denman-Beavers iteration for batched matrix square root |

All helpers are general-purpose — _matrix_exp_gpu is already shared with Stiefel.jl.
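A hedged sketch of how these three helpers compose (plain Python on a single 2x2 matrix, not the PR's code): `inv_2x2` stands in for `_batched_inv_gpu` (a closed-form inverse here, batched LU on the GPU), `sqrtm_db` for `_batched_sqrtm_gpu` (Denman-Beavers: `Y → sqrt(A)`, `Z → inv(sqrt(A))`), and `logm_iss` for `_matrix_log_gpu` (4 square roots pull `A` toward `I`, then a 16-term Taylor series for `log(I+X)`, scaled back by `2^4`).

```python
import math

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def inv_2x2(A):
    # Closed-form 2x2 inverse; the GPU helper uses batched LU instead.
    (a, b), (c, d) = A
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def sqrtm_db(A, iters=10):
    # Denman-Beavers: Y -> sqrt(A), Z -> inv(sqrt(A)); only matmuls + inverses.
    Y, Z = A, [[1.0, 0.0], [0.0, 1.0]]
    for _ in range(iters):
        Yi, Zi = inv_2x2(Y), inv_2x2(Z)
        Y = [[0.5 * (Y[i][j] + Zi[i][j]) for j in range(2)] for i in range(2)]
        Z = [[0.5 * (Z[i][j] + Yi[i][j]) for j in range(2)] for i in range(2)]
    return Y

def logm_iss(A, sqrts=4, terms=16):
    for _ in range(sqrts):            # bring A close to the identity
        A = sqrtm_db(A)
    X = [[A[i][j] - (1.0 if i == j else 0.0) for j in range(2)]
         for i in range(2)]
    L = [[0.0, 0.0], [0.0, 0.0]]
    P = [[1.0, 0.0], [0.0, 1.0]]
    for k in range(1, terms + 1):     # log(I+X) = sum (-1)^(k+1) X^k / k
        P = matmul(P, X)
        s = (-1.0) ** (k + 1) / k
        L = [[L[i][j] + s * P[i][j] for j in range(2)] for i in range(2)]
    return [[2**sqrts * L[i][j] for j in range(2)] for i in range(2)]

Lg = logm_iss([[math.e, 0.0], [0.0, 1.0]])  # approximately [[1, 0], [0, 0]]
```

The same structure explains why `log!` is the most expensive kernel: every square root is itself an iterative loop over batched inversions.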

File structure

ext/ManifoldsGPUCUDAExt/
├── ManifoldsGPUCUDAExt.jl              # +1 line: include GeneralUnitaryMatrices
├── helpers.jl                          # +127 lines: complex exp, log, inv, sqrtm
├── GeneralUnitaryMatrices.jl           # +128 lines: 7 method overrides
├── Stiefel.jl                          # unchanged
└── Grassmann.jl                        # unchanged

test/
├── runtests.jl                         # restructured: includes jlarray/ and cuda/
├── jlarray/
│   ├── test_stiefel.jl                 # extracted from old runtests.jl
│   └── test_general_unitary_matrices.jl  # new: 32 tests
└── cuda/
    ├── test_stiefel.jl                 # extracted from old runtests.jl
    └── test_general_unitary_matrices.jl  # new: 158 tests

benchmarks/
└── GeneralUnitaryMatrices.jl           # new: Rotations + UnitaryMatrices benchmarks

Tests

  • JLArray tests (CI-safe, allowscalar(false)): 32 tests covering exp!, retract!, project! for Rotations/UnitaryMatrices in Float32/Float64. log! is CUDA-only (requires getrf/getri).
  • CUDA tests: 158 tests covering all 7 methods across Rotations, OrthogonalMatrices, UnitaryMatrices in Float32/Float64. Validates manifold membership (is_point) and CPU/GPU agreement.
  • Existing Stiefel tests extracted into the same directory structure (no test logic changed).

Benchmark results (RTX 3090, n=16, batch=2048, Float32)

| Manifold | Operation | CPU (ms) | GPU (ms) | Speedup |
| --- | --- | --- | --- | --- |
| `Rotations(16)` | `exp!` | 26.2 | 1.4 | 18.2x |
| `Rotations(16)` | `log!` | 238.1 | 25.7 | 9.3x |
| `Rotations(16)` | `retract!` (Polar) | 72.6 | 3.5 | 20.8x |
| `UnitaryMatrices(16)` | `exp!` | 51.0 | 3.1 | 16.4x |
| `UnitaryMatrices(16)` | `log!` | 364.1 | 24.6 | 14.8x |

The log! speedup is lower because of the iterative Denman-Beavers square root (4 square roots × 10 iterations, each requiring a batched LU inversion). It is still a significant improvement over the serial CPU eigendecomposition.
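The cost asymmetry is easy to see with back-of-envelope counts (hedged: this uses the iteration counts quoted above, plus the assumption that each Denman-Beavers step inverts both its Y and Z iterates):

```python
# Rough operation count for the GPU log! hot path, from the figures quoted in
# this PR: 4 square roots, 10 Denman-Beavers iterations each, 16 Taylor terms.
# Two inversions per iteration is an assumption about the Y/Z update.
sqrts, db_iters, invs_per_iter = 4, 10, 2
batched_lu_inversions = sqrts * db_iters * invs_per_iter
taylor_matmuls = 16  # one batched matmul per log(I+X) series term
print(batched_lu_inversions, taylor_matmuls)  # 80 16
```

Eighty batched LU inversions dwarf the handful of batched GEMMs that exp! needs, which matches the smaller measured speedup.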

Design notes

  • gesvdj! has a size limit in cuSOLVER — CPU fallback is provided for matrices exceeding it (wrapped in try/catch, intentionally outside AD paths).
  • Real log! promotes to complex internally (unitary square roots may have complex entries), then takes real().
  • All operations stay fully on-device — no host↔device transfers in the hot path.

@mateuszbaran
Member

Nice, good work. The design here looks right 👍.

@kellertuer
Member

Nice work. I think in this small package we can also make the GPU packages hard dependencies and not use extensions?
Just an idea.

@mateuszbaran
Member

Using extensions has the benefit of not having to load the GPU packages like CUDA on unsupported platforms. We only load what's needed. It will be relevant when we support multiple different backends (for example ROCm, Metal).

@kellertuer
Member

Ok. If new methods are introduced, they could still be documented in the main file for visibility.

@mateuszbaran
Member

Yes, documentation is important. I will try organizing the docs later. For now, this LGTM 👍

@mateuszbaran mateuszbaran merged commit 47135bb into JuliaManifolds:main Mar 13, 2026
3 checks passed