Skip to content

feat: GPU overrides for Euclidean PowerManifold#2

Closed
zazabap wants to merge 1 commit intoJuliaManifolds:mainfrom
zazabap:feat/euclidean-gpu-v3
Closed

feat: GPU overrides for Euclidean PowerManifold#2
zazabap wants to merge 1 commit intoJuliaManifolds:mainfrom
zazabap:feat/euclidean-gpu-v3

Conversation

@zazabap
Copy link
Copy Markdown
Contributor

@zazabap zazabap commented Mar 6, 2026

Summary

Add 11 GPU-native overrides for PowerManifold(Euclidean) on CuArray, plus test/benchmark restructuring.

GPU Overrides

Each override replaces ManifoldsBase's default for i in get_iterator(M) per-element loop with fused GPU broadcasting, avoiding scalar indexing on CuArray:

Function GPU Strategy Why override is needed
exp! q .= p .+ X Default loops per power element
log! X .= q .- p Default loops per power element
distance sqrt(sum((p .- q) .^ 2)) Default calls log! + norm in loop
inner sum(X .* Y) Default loops via _inner_r helper
norm sqrt(sum(X .^ 2)) Default loops via _norm_r (independent of inner)
parallel_transport_to! Y .= X Default loops per element
project! (point) q .= p Default loops per element
project! (vector) Y .= X Default loops per element
zero_vector! X .= zero(T) Default loops per element
mid_point! q .= (p1 .+ p2) ./ 2 Default loops per element
vector_transport_to! Y .= X Default loops per element

All 11 functions have independent per-element loops in ManifoldsBase's PowerManifold

Test Restructuring

  • Split tests into test/jlarray/ and test/cuda/ subdirectories with per-manifold files
  • runtests.jl dispatches JLArray tests always, CUDA tests only when CUDA.functional()
  • Added JLArray Stiefel tests for CI without GPU hardware
  • Both Float32 and Float64 coverage

Benchmark Restructuring

  • Extracted shared helpers into benchmarks/common.jl
  • Renamed benchmarks/main.jlbenchmarks/Stiefel.jl
  • Added benchmarks/Euclidean.jl covering all 11 overrides

Benchmark Results (RTX 3090, batch=2048, n=32, k=16)

Operation CPU (ms) GPU (ms) Speedup
exp! 0.458 0.042 10.9×
log! 0.480 0.044 10.8×
distance 0.607 0.056 10.8×
inner 0.594 0.054 11.0×
norm 0.360 0.033 10.8×
parallel_transport_to! 0.470 0.044 10.7×
project! (point) 0.237 0.038 6.2×
project! (vector) 0.236 0.036 6.6×
zero_vector! 0.139 0.047 2.9×
mid_point! 0.340 0.038 8.9×
vector_transport_to! 0.479 0.043 11.2×

Test plan

  • All 77 tests pass (JLArray + CUDA)
  • Float32 and Float64 tested with correct tolerances
  • is_point manifold membership checks for exp! results
  • CPU/GPU agreement verified with isapprox
  • No scalar indexing in any GPU override
  • All 11 benchmarks show speedup

…ucture

Add 11 GPU-native overrides for PowerManifold(Euclidean) on CuArray:
exp!, log!, distance, inner, norm, parallel_transport_to!, project! (point),
project! (vector), zero_vector!, mid_point!, vector_transport_to!.

Each override replaces ManifoldsBase's default per-element loop
(for i in get_iterator(M)) with fused GPU broadcasting, avoiding
scalar indexing on CuArray.

Also:
- Split tests into test/jlarray/ and test/cuda/ subdirectories
- Add JLArray Stiefel tests for CI without GPU hardware
- Extract shared benchmark helpers into benchmarks/common.jl
- Rename benchmarks/main.jl to benchmarks/Stiefel.jl
- Add benchmarks/Euclidean.jl covering all 11 overrides
- Add JLArrays and GPUArrays as test dependencies
@zazabap
Copy link
Copy Markdown
Contributor Author

zazabap commented Mar 6, 2026

Some of the tests loose the CPU GPU difference to 1e-12 instead of 2e-14.

@mateuszbaran
Copy link
Copy Markdown
Member

Thanks for continuing the work here. First, I think the CUDA/JLArray test split is a good idea. However, if you need to use GPUArrays.allowscalar(true), then it means that the test effectively fails. I think JLArray will mostly be useful for tracking any possible regressions in generic Manifolds.jl code that doesn't have CUDA overrides. If we have a CuArray-specific implementation, we likely can't test it with JLArray. So please only add JLArray test that pass with GPUArrays.allowscalar(false). So for example we could have some JLArray tests with plain Stiefel.

Next, I think we should be quite restrictive about which manifolds have PowerManifold{ℝ, TM, <:Tuple, ArrayPowerRepresentation} overrides. For example power manifold of Euclidean is the same as just Euclidean with more indices. There is nothing new introduced but the code complexity grows, which makes it harder to maintain. Let's limit these overrides to what is actually useful see here: JuliaManifolds/Manifolds.jl#856 (comment) . For other manifolds we can have a table in docs explaining how to obtain the functionality. For example power manifold of Euclidean is still Euclidean, power manifold of Sphere (in the one-index case) is Oblique.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants