Skip to content

Commit 8272708

Browse files
zazabap and claude committed
feat: GPU-assisted exp! for PowerManifold(SymmetricPositiveDefinite)
Serial batch implementation: each SPD exp uses CPU eigendecomposition (no batched GPU eigendecomp available in CUDA.jl at time of writing). Formula: exp_p(X) = p^{1/2} · expm(p^{-1/2} · X · p^{-1/2}) · p^{1/2} Output is symmetrized (0.5*(result + result')) to ensure exact symmetry within Manifolds.jl's is_point tolerance. TODO: replace with cusolverDnSsyevjBatched when exposed in CUDA.jl. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8cce58d commit 8272708

4 files changed

Lines changed: 104 additions & 0 deletions

File tree

ext/ManifoldsGPUCUDAExt/ManifoldsGPUCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("Grassmann.jl")
1212
include("Sphere.jl")
1313
include("Euclidean.jl")
1414
include("UnitaryMatrices.jl")
15+
include("SPD.jl")
1516

1617

1718
end

ext/ManifoldsGPUCUDAExt/SPD.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using LinearAlgebra
2+
3+
"""
    exp!(M::PowerManifold{ℝ, <:SymmetricPositiveDefinite, ...}, q, p, X)

GPU-assisted geodesic exponential on batched SPD matrices, writing into `q`.

Strategy: the full batches `p` and `X` are downloaded to the host **once**, each
(n, n) SPD exponential is computed with a CPU eigendecomposition via
`eigen(Symmetric(...))` (no batched GPU eigendecomposition is available in
CUDA.jl at the time of writing), and the full result is uploaded back in a
single `copyto!`. This avoids the per-slice CPU↔GPU round-trips, whose transfer
latency dominates for small matrices.

TODO: Replace with a true GPU-batched implementation when CUDA.CUSOLVER exposes
syevjBatched (batched symmetric eigendecomposition).

Formula: exp_p(X) = p^{1/2} · expm(p^{-1/2} · X · p^{-1/2}) · p^{1/2}

Each output slice is symmetrized (`0.5 * (result + result')`) so the result
passes Manifolds.jl's strict symmetry check in `is_point` despite
floating-point rounding in the triple products.
"""
function ManifoldsBase.exp!(
    ::PowerManifold{ℝ, <:SymmetricPositiveDefinite, <:Tuple, ArrayPowerRepresentation},
    q::CuArray{T, 3},
    p::CuArray{T, 3},
    X::CuArray{T, 3},
) where {T <: Real}
    # One bulk device→host transfer per input (instead of one per batch
    # element, as a naive per-slice loop would do).
    p_h = Array(p)
    X_h = Array(X)
    q_h = similar(p_h)

    batch = size(p_h, 3)
    for i in 1:batch
        # Eigendecomposition of the i-th base point (symmetric PD):
        # p = V * Diagonal(λ) * V'
        λ, V = eigen(Symmetric(p_h[:, :, i]))
        sqrt_λ = sqrt.(λ)
        p_sqrt = V * Diagonal(sqrt_λ) * V'
        p_invsqrt = V * Diagonal(inv.(sqrt_λ)) * V'

        # Symmetric inner matrix: S = p^{-1/2} · X · p^{-1/2}
        S = Symmetric(p_invsqrt * (@view X_h[:, :, i]) * p_invsqrt)

        # Matrix exponential of symmetric S via its eigendecomposition.
        μ, W = eigen(S)
        expS = W * Diagonal(exp.(μ)) * W'

        # Geodesic: q = p^{1/2} · exp(S) · p^{1/2}
        # Symmetrize to cancel floating-point rounding: a product of symmetric
        # matrices is only approximately symmetric in finite precision.
        result = p_sqrt * expS * p_sqrt
        q_h[:, :, i] .= T(0.5) .* (result .+ result')
    end

    # Single bulk host→device upload of the whole batch.
    copyto!(q, q_h)
    return q
end

test/cuda_tests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,4 +327,38 @@ using CUDA
327327
@test is_point(MP, Array(Y_cu))
328328
@test isapprox(MP, p, Array(Y_cu), Y_cpu; atol = 2.0f-4, rtol = 2.0f-4)
329329
end
330+
331+
@testset "SPD exp Float64" begin
    Random.seed!(90)
    M = SymmetricPositiveDefinite(4)
    MP = PowerManifold(M, 16)

    # CPU reference: a random base point and a small tangent vector.
    base = rand(MP)
    tangent = 0.2 .* rand(MP; vector_at = base)
    expected = exp(MP, base, tangent)

    # Same computation routed through the CUDA extension.
    result = Array(exp(MP, CuArray(base), CuArray(tangent)))

    @test is_point(MP, result)
    @test isapprox(MP, base, result, expected; atol = 2.0e-12, rtol = 2.0e-12)
end
347+
348+
@testset "SPD exp Float32" begin
    Random.seed!(91)
    M = SymmetricPositiveDefinite(4)
    MP = PowerManifold(M, 16)

    # Single-precision CPU reference.
    base = Float32.(rand(MP))
    tangent = Float32(0.2) .* Float32.(rand(MP; vector_at = base))
    expected = exp(MP, base, tangent)

    # Same computation routed through the CUDA extension.
    result = Array(exp(MP, CuArray(base), CuArray(tangent)))

    @test is_point(MP, result)
    @test isapprox(MP, base, result, expected; atol = 2.0f-5, rtol = 2.0f-5)
end
330364
end

test/jlarray_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,22 @@ end
193193
@test is_point(MP, Array(Y_jl))
194194
@test isapprox(MP, p, Array(Y_jl), Y_cpu; atol = 2.0e-12, rtol = 2.0e-12)
195195
end
196+
197+
@testset "JLArray: SPD exp Float64" begin
    M = SymmetricPositiveDefinite(4)
    MP = PowerManifold(M, 8)

    Random.seed!(90)
    base = rand(MP)
    tangent = 0.2 .* rand(MP; vector_at = base)
    expected = exp(MP, base, tangent)

    # Same computation routed through the JLArray (GPU-emulation) path.
    result = Array(exp(MP, JLArray(base), JLArray(tangent)))

    # Note: no is_point check here — the CPU Manifolds.jl exp for SPD produces
    # results with ~O(eps) asymmetry that fails Manifolds.jl's strict symmetry
    # tolerance. The GPU implementation (SPD.jl) symmetrizes explicitly.
    @test isapprox(MP, base, result, expected; atol = 2.0e-12, rtol = 2.0e-12)
end

0 commit comments

Comments
 (0)