|
| 1 | +using LinearAlgebra |
| 2 | + |
| 3 | +""" |
| 4 | + exp!(M::PowerManifold{ℝ, <:SymmetricPositiveDefinite, ...}, q, p, X) |
| 5 | +
|
| 6 | +GPU-assisted geodesic exponential on batched SPD matrices. |
| 7 | +
|
| 8 | +Strategy: serial loop over batch — each (n,n) SPD exp uses CPU eigendecomposition |
| 9 | +via eigen(Symmetric(...)). Slices are transferred CPU↔GPU per batch element. |
| 10 | +
|
| 11 | +TODO: Replace with true GPU-batched implementation when CUDA.CUSOLVER exposes |
| 12 | +syevjBatched (batched symmetric eigendecomposition). |
| 13 | +
|
| 14 | +Formula: exp_p(X) = p^{1/2} · expm(p^{-1/2} · X · p^{-1/2}) · p^{1/2} |
| 15 | +""" |
| 16 | +function ManifoldsBase.exp!( |
| 17 | + ::PowerManifold{ℝ, <:SymmetricPositiveDefinite, <:Tuple, ArrayPowerRepresentation}, |
| 18 | + q::CuArray{T, 3}, |
| 19 | + p::CuArray{T, 3}, |
| 20 | + X::CuArray{T, 3}, |
| 21 | +) where {T <: Real} |
| 22 | + batch = size(q, 3) |
| 23 | + for i in 1:batch |
| 24 | + # Transfer to CPU for eigendecomposition (no batched GPU alternative yet) |
| 25 | + p_i = Array(@view p[:, :, i]) |
| 26 | + X_i = Array(@view X[:, :, i]) |
| 27 | + |
| 28 | + # Eigendecomposition of p_i (symmetric PD): p = V * diag(λ) * V' |
| 29 | + λ, V = eigen(Symmetric(p_i)) |
| 30 | + sqrt_λ = sqrt.(λ) |
| 31 | + invsqrt_λ = inv.(sqrt_λ) |
| 32 | + p_sqrt = V * Diagonal(sqrt_λ) * V' |
| 33 | + p_invsqrt = V * Diagonal(invsqrt_λ) * V' |
| 34 | + |
| 35 | + # Symmetric inner matrix: S = p^{-1/2} · X · p^{-1/2} |
| 36 | + S = Symmetric(p_invsqrt * X_i * p_invsqrt) |
| 37 | + |
| 38 | + # Matrix exponential of symmetric S via eigendecomposition |
| 39 | + μ, W = eigen(S) |
| 40 | + expS = W * Diagonal(exp.(μ)) * W' |
| 41 | + |
| 42 | + # Geodesic: q = p^{1/2} · exp(S) · p^{1/2} |
| 43 | + # Symmetrize to cancel floating-point rounding: product of symmetric |
| 44 | + # matrices is only approximately symmetric in finite precision. |
| 45 | + result = p_sqrt * expS * p_sqrt |
| 46 | + @view(q[:, :, i]) .= CuArray(T.(0.5 .* (result .+ result'))) |
| 47 | + end |
| 48 | + |
| 49 | + return q |
| 50 | +end |
0 commit comments