feat: GPU overrides for Euclidean PowerManifold#2
feat: GPU overrides for Euclidean PowerManifold#2zazabap wants to merge 1 commit intoJuliaManifolds:mainfrom
Conversation
…ucture Add 11 GPU-native overrides for PowerManifold(Euclidean) on CuArray: exp!, log!, distance, inner, norm, parallel_transport_to!, project! (point), project! (vector), zero_vector!, mid_point!, vector_transport_to!. Each override replaces ManifoldsBase's default per-element loop (for i in get_iterator(M)) with fused GPU broadcasting, avoiding scalar indexing on CuArray. Also: - Split tests into test/jlarray/ and test/cuda/ subdirectories - Add JLArray Stiefel tests for CI without GPU hardware - Extract shared benchmark helpers into benchmarks/common.jl - Rename benchmarks/main.jl to benchmarks/Stiefel.jl - Add benchmarks/Euclidean.jl covering all 11 overrides - Add JLArrays and GPUArrays as test dependencies
|
Some of the tests loose the CPU GPU difference to 1e-12 instead of 2e-14. |
|
Thanks for continuing the work here. First, I think the CUDA/JLArray test split is a good idea. However, if you need to use Next, I think we should be quite restrictive about which manifolds have |
Summary
Add 11 GPU-native overrides for
PowerManifold(Euclidean)onCuArray, plus test/benchmark restructuring.GPU Overrides
Each override replaces ManifoldsBase's default
for i in get_iterator(M)per-element loop with fused GPU broadcasting, avoiding scalar indexing on CuArray:exp!q .= p .+ Xlog!X .= q .- pdistancesqrt(sum((p .- q) .^ 2))log!+normin loopinnersum(X .* Y)_inner_rhelpernormsqrt(sum(X .^ 2))_norm_r(independent ofinner)parallel_transport_to!Y .= Xproject!(point)q .= pproject!(vector)Y .= Xzero_vector!X .= zero(T)mid_point!q .= (p1 .+ p2) ./ 2vector_transport_to!Y .= XAll 11 functions have independent per-element loops in ManifoldsBase's
PowerManifoldTest Restructuring
test/jlarray/andtest/cuda/subdirectories with per-manifold filesruntests.jldispatches JLArray tests always, CUDA tests only whenCUDA.functional()Benchmark Restructuring
benchmarks/common.jlbenchmarks/main.jl→benchmarks/Stiefel.jlbenchmarks/Euclidean.jlcovering all 11 overridesBenchmark Results (RTX 3090, batch=2048, n=32, k=16)
Test plan
is_pointmanifold membership checks for exp! resultsisapprox