"""
    _time_median(f; samples = 6)

Call the zero-argument function `f` `samples` times and time every call.

Returns a tuple `(median_ms, all_ms)`: the median wall-clock duration in
milliseconds and the vector of the individual durations (also in milliseconds).
A full garbage collection is triggered before each timed call.
"""
function _time_median(f; samples::Int = 6)
    elapsed_ms = Vector{Float64}(undef, samples)
    for k in 1:samples
        # Collect garbage before starting the clock, not inside the timed region.
        GC.gc()
        started = time_ns()
        f()
        stopped = time_ns()
        # time_ns() counts nanoseconds; report milliseconds.
        elapsed_ms[k] = (stopped - started) / 1.0e6
    end
    return median(elapsed_ms), elapsed_ms
end

"""
    _benchmark_cpu_gpu(cpu_f, gpu_f; samples)

Time `cpu_f` and `gpu_f` with [`_time_median`](@ref).

Both functions are called once before timing so that compilation is not
measured. Returns `(cpu_median_ms, cpu_all_ms, gpu_median_ms, gpu_all_ms)`.
"""
function _benchmark_cpu_gpu(cpu_f, gpu_f; samples::Int)
    # Warm-up calls (results discarded).
    cpu_f()
    gpu_f()

    cpu_stats = _time_median(cpu_f; samples = samples)
    gpu_stats = _time_median(gpu_f; samples = samples)

    return cpu_stats[1], cpu_stats[2], gpu_stats[1], gpu_stats[2]
end
"""
    _print_results(; name, manifold_label, samples, cpu_all, gpu_all, cpu_ms, gpu_ms,
        relerr, relerr_label, extra_lines = String[], eltype_label = "Float32")

Print the report for one CPU-vs-GPU benchmark case to `stdout`.

# Keywords
- `name`, `manifold_label`: used in the header line.
- `samples`: number of timed samples.
- `cpu_all`, `gpu_all`: per-sample timings in milliseconds.
- `cpu_ms`, `gpu_ms`: median timings in milliseconds; their ratio is the speedup.
- `relerr`, `relerr_label`: relative CPU/GPU discrepancy and the formula shown for it.
- `extra_lines`: additional lines printed right after the element type.
- `eltype_label`: element type shown in the report. Defaults to `"Float32"`, the
  value that used to be hard-coded, so existing callers print the same text.
"""
function _print_results(;
        name::String,
        manifold_label::String,
        samples::Int,
        cpu_all,
        gpu_all,
        cpu_ms::Float64,
        gpu_ms::Float64,
        relerr,
        relerr_label::String,
        extra_lines::Vector{String} = String[],
        eltype_label::String = "Float32",
    )
    speedup = cpu_ms / gpu_ms

    println("=== ManifoldsGPU benchmark: $name on $manifold_label ===")
    println("Element type: $eltype_label")
    foreach(println, extra_lines)
    println("Samples: $samples")
    println("CPU times [ms]: ", round.(cpu_all; digits = 2))
    println("GPU times [ms]: ", round.(gpu_all; digits = 2))
    println("Median CPU [ms]: ", round(cpu_ms; digits = 2))
    println("Median GPU [ms]: ", round(gpu_ms; digits = 2))
    println("Speedup (CPU/GPU): ", round(speedup; digits = 2), "x")
    return println("Relative error $relerr_label: ", relerr)
end

"""
    _parse_arg(i, default)

Return the `i`-th command-line argument parsed as `typeof(default)`, or `default`
when fewer than `i` arguments were passed.
"""
function _parse_arg(i::Int, default)
    return length(ARGS) >= i ? parse(typeof(default), ARGS[i]) : default
end

# Seed the RNG, draw two points and one scaled tangent vector on
# `PowerManifold(M, batch)`, convert them to element type `T`, and upload copies
# to the GPU. Shared by the two `_setup_*_data` entry points.
function _setup_power_data(M, ::Type{T}; batch::Int, scale::Float32, seed::Int) where {T}
    Random.seed!(seed)

    MP = PowerManifold(M, batch)

    p_cpu = T.(rand(MP))
    X_cpu = scale .* T.(rand(MP; vector_at = p_cpu))
    q_cpu = T.(rand(MP))

    p_gpu = CuArray(p_cpu)
    X_gpu = CuArray(X_cpu)
    q_gpu = CuArray(q_cpu)

    return (; MP, p_cpu, X_cpu, q_cpu, p_gpu, X_gpu, q_gpu)
end

"""
    _setup_rotations_data(; n, batch, scale, seed)

Benchmark inputs on `PowerManifold(Rotations(n), batch)` with `Float32` entries.
Returns a named tuple `(; MP, p_cpu, X_cpu, q_cpu, p_gpu, X_gpu, q_gpu)`.
"""
function _setup_rotations_data(; n::Int, batch::Int, scale::Float32, seed::Int)
    return _setup_power_data(
        Rotations(n), Float32; batch = batch, scale = scale, seed = seed,
    )
end

"""
    _setup_unitary_data(; n, batch, scale, seed)

Benchmark inputs on `PowerManifold(UnitaryMatrices(n), batch)` with `ComplexF32`
entries. Returns the same named tuple layout as `_setup_rotations_data`.
"""
function _setup_unitary_data(; n::Int, batch::Int, scale::Float32, seed::Int)
    return _setup_power_data(
        UnitaryMatrices(n), ComplexF32; batch = batch, scale = scale, seed = seed,
    )
end

# Benchmark one operation, compute the relative CPU/GPU discrepancy from one
# additional call of each closure, and print the report.
# `cpu_f()` must return the CPU result and `gpu_f()` the (synchronized) GPU result.
function _benchmark_and_report(
        cpu_f, gpu_f;
        name::String,
        manifold_label::String,
        samples::Int,
        relerr_label::String,
        eltype_label::String,
        extra_lines::Vector{String} = String[],
    )
    cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu(cpu_f, gpu_f; samples = samples)

    reference = cpu_f()
    device = Array(gpu_f())
    # eps(Float32) guards against division by zero for an all-zero reference.
    relerr = norm(reference .- device) / max(norm(reference), eps(Float32))

    return _print_results(;
        name = name,
        manifold_label = manifold_label,
        samples = samples,
        cpu_all = cpu_all,
        gpu_all = gpu_all,
        cpu_ms = cpu_ms,
        gpu_ms = gpu_ms,
        relerr = relerr,
        relerr_label = relerr_label,
        extra_lines = extra_lines,
        eltype_label = eltype_label,
    )
end

# Benchmark `exp` and `log` on the prepared data; shared by both manifolds.
function _benchmark_exp_log(data; manifold_label::String, samples::Int)
    (; MP, p_cpu, X_cpu, q_cpu, p_gpu, X_gpu, q_gpu) = data
    eltype_label = string(eltype(p_cpu))

    cpu_exp() = exp(MP, p_cpu, X_cpu)
    gpu_exp() = CUDA.@sync exp(MP, p_gpu, X_gpu)
    _benchmark_and_report(
        cpu_exp, gpu_exp;
        name = "ExponentialRetraction",
        manifold_label = manifold_label,
        samples = samples,
        relerr_label = "||Ycpu - Ygpu||/||Ycpu||",
        eltype_label = eltype_label,
    )

    println()

    cpu_log() = log(MP, p_cpu, q_cpu)
    gpu_log() = CUDA.@sync log(MP, p_gpu, q_gpu)
    return _benchmark_and_report(
        cpu_log, gpu_log;
        name = "LogarithmicMap",
        manifold_label = manifold_label,
        samples = samples,
        relerr_label = "||Xcpu - Xgpu||/||Xcpu||",
        eltype_label = eltype_label,
    )
end

"""
    benchmark_rotations(; n = 16, batch = 2048, scale = 0.2f0, t = 0.3f0,
        samples = 6, seed = 1234)

Benchmark `exp`, `log` and the fused polar retraction on
`PowerManifold(Rotations(n), batch)` for CPU arrays versus `CuArray`s and print
one report per operation.
"""
function benchmark_rotations(;
        n::Int = 16,
        batch::Int = 2048,
        scale::Float32 = 0.2f0,
        t::Float32 = 0.3f0,
        samples::Int = 6,
        seed::Int = 1234,
    )
    data = _setup_rotations_data(; n = n, batch = batch, scale = scale, seed = seed)
    manifold_label = "PowerManifold(Rotations($n), $batch)"

    _benchmark_exp_log(data; manifold_label = manifold_label, samples = samples)

    println()

    # Polar retraction writes into dedicated output buffers. They get their own
    # names so the closures below do not capture a variable that is rebound.
    (; MP, p_cpu, X_cpu, p_gpu, X_gpu) = data
    q_out_cpu = similar(p_cpu)
    q_out_gpu = similar(p_gpu)
    function cpu_polar()
        ManifoldsBase.retract_fused!(MP, q_out_cpu, p_cpu, X_cpu, t, PolarRetraction())
        return q_out_cpu
    end
    function gpu_polar()
        CUDA.@sync ManifoldsBase.retract_fused!(
            MP, q_out_gpu, p_gpu, X_gpu, t, PolarRetraction(),
        )
        return q_out_gpu
    end
    return _benchmark_and_report(
        cpu_polar, gpu_polar;
        name = "PolarRetraction",
        manifold_label = manifold_label,
        samples = samples,
        relerr_label = "||Qcpu - Qgpu||/||Qcpu||",
        eltype_label = string(eltype(p_cpu)),
        extra_lines = ["Retraction scalar t: $t"],
    )
end

"""
    benchmark_unitary(; n = 16, batch = 2048, scale = 0.2f0, samples = 6, seed = 1234)

Benchmark `exp` and `log` on `PowerManifold(UnitaryMatrices(n), batch)` for CPU
arrays versus `CuArray`s and print one report per operation. The reports show
the element type of the generated data (`ComplexF32`).
"""
function benchmark_unitary(;
        n::Int = 16,
        batch::Int = 2048,
        scale::Float32 = 0.2f0,
        samples::Int = 6,
        seed::Int = 1234,
    )
    data = _setup_unitary_data(; n = n, batch = batch, scale = scale, seed = seed)
    manifold_label = "PowerManifold(UnitaryMatrices($n), $batch)"
    return _benchmark_exp_log(data; manifold_label = manifold_label, samples = samples)
end

"""
    main()

Script entry point. Reads `n`, `batch` and `samples` from the first three
command-line arguments (defaults 16, 2048, 6) and runs both benchmark suites.
"""
function main()
    n = _parse_arg(1, 16)
    batch = _parse_arg(2, 2048)
    samples = _parse_arg(3, 6)

    println("=== Rotations benchmarks ===")
    println("Running with n=$n, batch=$batch, samples=$samples")
    println()
    benchmark_rotations(; n = n, batch = batch, samples = samples)

    println()
    println("=== UnitaryMatrices benchmarks ===")
    println("Running with n=$n, batch=$batch, samples=$samples")
    println()
    return benchmark_unitary(; n = n, batch = batch, samples = samples)
end

main()
using LinearAlgebra

"""
    exp!(::PowerManifold{…GeneralUnitaryMatrices…}, q::CuArray{T,3}, p, X)

Batched exponential map: for every slice `i`, `q[:, :, i] = p[:, :, i] * exp(X[:, :, i])`.
The matrix exponentials come from `_matrix_exp_gpu`, the products from one
strided-batched GEMM. Returns `q`.
"""
function ManifoldsBase.exp!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        q::CuArray{T, 3},
        p::CuArray{T, 3},
        X::CuArray{T, 3},
    ) where {T <: Number}
    E = _matrix_exp_gpu(X)
    q .= CUDA.CUBLAS.gemm_strided_batched('N', 'N', p, E)
    return q
end

# Overwrite every slice `q[:, :, i]` with `U * V'` taken from its SVD `U * S * V'`.
#
# The batched Jacobi SVD (`gesvdj!`) is tried first. If it throws an
# `ArgumentError`, each slice is factorized separately with `svd!` instead; any
# other exception is rethrown.
# NOTE(review): the previous comments called the per-slice branch a "CPU fallback"
# for matrices above the size `gesvdj!` supports. The slice is only copied with
# `copy`, never moved to the host, so where `svd!` runs depends on the array type
# `copy` returns — confirm.
# NOTE: intentionally non-differentiable (retract!/project! for points are not
# differentiated through directly).
function _unitary_polar_factor_batched!(q::CuArray{T, 3}) where {T <: Number}
    try
        U, _, V = CUDA.CUSOLVER.gesvdj!('V', q)
        q .= CUDA.CUBLAS.gemm_strided_batched('N', 'C', U, V)
    catch e
        e isa ArgumentError || rethrow()
        for i in axes(q, 3)
            # svd! destroys its argument, so factorize a copy of the slice.
            q_i = copy(@view q[:, :, i])
            s = svd!(q_i)
            @view(q[:, :, i]) .= s.U * s.Vt
        end
    end
    return q
end

"""
    retract_polar_fused!(::PowerManifold{…GeneralUnitaryMatrices…}, q::CuArray{T,3}, p, X, t)

Batched polar retraction. Computes `p + p * (t * X)` slice-wise and replaces each
slice by the factor `U * V'` of its SVD. `t` is converted to `T`. Returns `q`.
"""
function ManifoldsBase.retract_polar_fused!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        q::CuArray{T, 3},
        p::CuArray{T, 3},
        X::CuArray{T, 3},
        t::Number,
    ) where {T <: Number}
    q .= p .+ CUDA.CUBLAS.gemm_strided_batched('N', 'N', p, T(t) .* X)
    return _unitary_polar_factor_batched!(q)
end

"""
    retract_fused!(M::PowerManifold{…GeneralUnitaryMatrices…}, q::CuArray{T,3}, p, X, t,
        ::PolarRetraction)

Forward the `PolarRetraction` case to `retract_polar_fused!`.
"""
function ManifoldsBase.retract_fused!(
        M::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        q::CuArray{T, 3},
        p::CuArray{T, 3},
        X::CuArray{T, 3},
        t::Number,
        ::PolarRetraction,
    ) where {T <: Number}
    return ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
end

"""
    project!(::PowerManifold{…GeneralUnitaryMatrices{…, AbsoluteDeterminantOneMatrixType}…},
        q::CuArray{T,3}, p::CuArray{T,3})

Batched point projection: copy `p` into `q` and replace each slice by the factor
`U * V'` of its SVD. Only defined for the `AbsoluteDeterminantOneMatrixType`
variant of `GeneralUnitaryMatrices`. Returns `q`.
"""
function ManifoldsBase.project!(
        ::PowerManifold{
            <:Any,
            <:Manifolds.GeneralUnitaryMatrices{
                <:Any,
                <:Any,
                Manifolds.AbsoluteDeterminantOneMatrixType,
            },
            <:Tuple,
            ArrayPowerRepresentation,
        },
        q::CuArray{T, 3},
        p::CuArray{T, 3},
    ) where {T <: Number}
    q .= p
    return _unitary_polar_factor_batched!(q)
end
"""
    project!(::PowerManifold{…GeneralUnitaryMatrices…}, Y::CuArray{T,3}, p, X)

Batched tangent-vector projection: for each slice, `Y = (p' * X - X' * p) / 2`,
evaluated with two strided-batched GEMMs. Returns `Y`.
"""
function ManifoldsBase.project!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        Y::CuArray{T, 3},
        p::CuArray{T, 3},
        X::CuArray{T, 3},
    ) where {T <: Number}
    pH_X = CUDA.CUBLAS.gemm_strided_batched('C', 'N', p, X)
    XH_p = CUDA.CUBLAS.gemm_strided_batched('C', 'N', X, p)
    @. Y = (pH_X - XH_p) / 2
    return Y
end

"""
    log!(::PowerManifold{…GeneralUnitaryMatrices…}, X::CuArray{T,3}, p, q) where {T <: Real}

Batched logarithmic map for real element types: takes `_matrix_log_gpu` of
`transpose(p) * q` slice-wise and then keeps the skew-symmetric part
`(X - transpose(X)) / 2`. Returns `X`.
"""
function ManifoldsBase.log!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        X::CuArray{T, 3},
        p::CuArray{T, 3},
        q::CuArray{T, 3},
    ) where {T <: Real}
    relative = CUDA.CUBLAS.gemm_strided_batched('T', 'N', p, q)
    X .= _matrix_log_gpu(relative)
    # Materialize the slice-wise transpose before overwriting X.
    Xt = permutedims(X, (2, 1, 3))
    @. X = (X - Xt) / T(2)
    return X
end

"""
    log!(::PowerManifold{…GeneralUnitaryMatrices…}, X::CuArray{T,3}, p, q) where {T <: Complex}

Batched logarithmic map for complex element types: takes `_matrix_log_gpu` of
`p' * q` slice-wise and then keeps the skew-Hermitian part `(X - X') / 2`.
Returns `X`.
"""
function ManifoldsBase.log!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        X::CuArray{T, 3},
        p::CuArray{T, 3},
        q::CuArray{T, 3},
    ) where {T <: Complex}
    relative = CUDA.CUBLAS.gemm_strided_batched('C', 'N', p, q)
    X .= _matrix_log_gpu(relative)
    # Materialize the slice-wise conjugate transpose before overwriting X.
    Xh = conj.(permutedims(X, (2, 1, 3)))
    @. X = (X - Xh) / T(2)
    return X
end
"""
    _matrix_exp_gpu(A::CuArray{T, 2}) where {T <: Complex}

Matrix exponential of a single complex matrix. The matrix is reshaped into a
batch of one, passed to the batched method, and the result is reshaped back.
"""
function _matrix_exp_gpu(A::CuArray{T, 2}) where {T <: Complex}
    E = _matrix_exp_gpu(reshape(A, size(A, 1), size(A, 2), 1))
    return reshape(E, size(A))
end

"""
    _matrix_exp_gpu(A::CuArray{T, 3}) where {T <: Complex}

Batched matrix exponential of the square slices `A[:, :, i]` by scaling and
squaring: `A` is divided by `2^s`, a truncated Taylor series (18 terms when the
real type is `Float32`, 30 otherwise) is summed, and the result is squared `s`
times. A single `s`, derived from the largest absolute entry of the whole batch,
is used for all slices.

An all-zero input returns identity matrices. An input containing `NaN`/`Inf` is
not rescaled and yields non-finite output. An empty input is returned as a copy.

# Throws
- `DimensionMismatch` if the slices are not square.
"""
function _matrix_exp_gpu(A::CuArray{T, 3}) where {T <: Complex}
    n, m, batch = size(A)
    n == m ||
        throw(DimensionMismatch("matrix exponential requires square matrices, got ($n, $m, $batch)"))

    # Nothing to exponentiate, and `maximum` below would throw on an empty array.
    isempty(A) && return copy(A)

    RT = real(T)
    I_n = reshape(CuArray(Matrix{T}(I, n, n)), n, n, 1)
    I_batch = similar(A)
    I_batch .= I_n

    maxabs = maximum(abs, A)
    theta = RT <: Float32 ? RT(1) : RT(1 // 2)
    scaled = float(real(maxabs) * n / theta)
    # Only take log2 of a finite value > 1. For `scaled == 0` (all-zero input)
    # log2 is -Inf and `ceil(Int, -Inf)` throws InexactError; the same applies to
    # NaN/Inf. In all those cases no scaling is applied.
    s = (isfinite(scaled) && scaled > 1) ? ceil(Int, log2(scaled)) : 0
    As = A ./ T(2)^s

    order = RT <: Float32 ? 18 : 30
    E = copy(I_batch)
    term = copy(I_batch)
    for j in 1:order
        # term = As^j / j!
        term = CUDA.CUBLAS.gemm_strided_batched('N', 'N', term, As)
        term ./= T(j)
        E .+= term
    end

    # Undo the scaling: exp(A) = exp(A / 2^s)^(2^s).
    for _ in 1:s
        E = CUDA.CUBLAS.gemm_strided_batched('N', 'N', E, E)
    end

    return E
end
"""
    _matrix_log_gpu(A::CuArray{T, 2}; kwargs...)
    _matrix_log_gpu(A::CuArray{T, 3}; sqrtm_count = 4, db_iters = 10, taylor_order = 16)

GPU matrix logarithm via Inverse Scaling & Squaring with Denman-Beavers iteration.

Uses the identity `log(A) = 2^s * log(A^{1/2^s})`, where repeated matrix square
roots bring `A` close to `I`, then a Taylor series computes `log(I + X)`. Matrix
square roots are computed via the Denman-Beavers iteration using batched LU-based
inversion (`getrf_strided_batched!` + `getri_strided_batched!`).

Real inputs are promoted to complex (square roots of unitary matrices may have
complex entries), then the real part is taken. Matrix (2-D) inputs are treated as
a batch of one.

# Keywords (defaults tuned for `Rotations(n)` with `n ≤ 32`)
- `sqrtm_count = 4`: number of repeated square roots (scaling factor `2^s`)
- `db_iters = 10`: Denman-Beavers iterations per square root
- `taylor_order = 16`: terms in the `log(I + X)` Taylor series

All methods accept these keywords; the real and the 2-D methods forward them to
the complex batched method.

# Throws
- `DimensionMismatch` if the slices are not square.
"""
function _matrix_log_gpu(A::CuArray{T, 2}; kwargs...) where {T <: Real}
    L = _matrix_log_gpu(reshape(A, size(A, 1), size(A, 2), 1); kwargs...)
    return reshape(L, size(A))
end

function _matrix_log_gpu(A::CuArray{T, 2}; kwargs...) where {T <: Complex}
    L = _matrix_log_gpu(reshape(A, size(A, 1), size(A, 2), 1); kwargs...)
    return reshape(L, size(A))
end

# Batched matrix inverse via LU factorization (used by Denman-Beavers).
# Works on a copy, so `A` is left untouched.
function _batched_inv_gpu(A::CuArray{T, 3}) where {T}
    A_lu = copy(A)
    pivot = CUDA.zeros(Int32, size(A, 1), size(A, 3))
    CUDA.CUBLAS.getrf_strided_batched!(A_lu, pivot)
    C = similar(A)
    CUDA.CUBLAS.getri_strided_batched!(A_lu, C, pivot)
    return C
end

# Denman-Beavers iteration for the batched matrix square root.
# Starting from Y = A, Z = I, iterate Y <- (Y + Z^-1) / 2, Z <- (Z + Y^-1) / 2
# for a fixed number of `iters` steps and return Y.
function _batched_sqrtm_gpu(A::CuArray{T, 3}; iters::Int = 10) where {T}
    nn = size(A, 1)
    I_n = reshape(CuArray(Matrix{T}(I, nn, nn)), nn, nn, 1)
    Y = copy(A)
    Z = similar(A)
    Z .= I_n
    for _ in 1:iters
        # Both inverses are taken from the previous iterates before either update.
        Zinv = _batched_inv_gpu(Z)
        Yinv = _batched_inv_gpu(Y)
        Y = (Y .+ Zinv) ./ T(2)
        Z = (Z .+ Yinv) ./ T(2)
    end
    return Y
end

function _matrix_log_gpu(
        A::CuArray{T, 3}; sqrtm_count::Int = 4, db_iters::Int = 10, taylor_order::Int = 16,
    ) where {T <: Complex}
    n, m, batch = size(A)
    n == m ||
        throw(DimensionMismatch("matrix logarithm requires square matrices, got ($n, $m, $batch)"))

    # Inverse Scaling: take s repeated square roots to bring A close to I
    B = copy(A)
    for _ in 1:sqrtm_count
        B = _batched_sqrtm_gpu(B; iters = db_iters)
    end

    # Taylor series: log(I + X) = X - X²/2 + X³/3 - ...
    I_n = reshape(CuArray(Matrix{T}(I, n, n)), n, n, 1)
    X_mat = B .- I_n
    L = copy(X_mat)
    term = copy(X_mat)
    for j in 2:taylor_order
        # term = X^j; even powers enter with a minus sign.
        term = CUDA.CUBLAS.gemm_strided_batched('N', 'N', term, X_mat)
        sign_j = iseven(j) ? T(-1) : T(1)
        L .+= sign_j .* term ./ T(j)
    end

    # Squaring: undo the scaling
    L .*= T(2)^sqrtm_count
    return L
end

function _matrix_log_gpu(A::CuArray{T, 3}; kwargs...) where {T <: Real}
    CT = complex(T)
    Ac = CuArray{CT}(A)
    # Forward the tuning keywords so real inputs can be tuned like complex ones.
    logAc = _matrix_log_gpu(Ac; kwargs...)
    return real.(logAc)
end
# Helpers for the CUDA GeneralUnitaryMatrices testsets. Each helper seeds the global
# RNG, draws CPU data, runs the operation on CPU arrays and on CuArrays, and compares
# the downloaded GPU result with the CPU reference. `cast` converts the sampled
# Float64/ComplexF64 data to the element type under test.

_gum_to_f32(A) = Float32.(A)

# exp on PowerManifold(M, 64), six random draws.
function _gum_check_exp(M; seed, cast = identity, scale, tol)
    Random.seed!(seed)
    MP = PowerManifold(M, 64)
    for _ in 1:6
        p = cast(rand(MP))
        X = scale .* cast(rand(MP; vector_at = p))
        expected = exp(MP, p, X)

        actual = Array(exp(MP, CuArray(p), CuArray(X)))

        @test is_point(MP, actual)
        @test isapprox(actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

# Fused polar retraction on PowerManifold(M, batch).
function _gum_check_polar(M, batch; seed, cast = identity, t, scale = 1, tol, trials = 1)
    Random.seed!(seed)
    MP = PowerManifold(M, batch)
    for _ in 1:trials
        p = cast(rand(MP))
        X = scale .* cast(rand(MP; vector_at = p))

        expected = similar(p)
        ManifoldsBase.retract_fused!(MP, expected, p, X, t, PolarRetraction())

        p_cu = CuArray(p)
        q_cu = similar(p_cu)
        ManifoldsBase.retract_fused!(MP, q_cu, p_cu, CuArray(X), t, PolarRetraction())
        actual = Array(q_cu)

        @test is_point(MP, actual)
        @test isapprox(actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

# Tangent projection of an unconstrained random array with element type T;
# the CPU reference projects slice by slice on the base manifold M.
function _gum_check_project_tangent(M, ::Type{T}; seed) where {T}
    Random.seed!(seed)
    MP = PowerManifold(M, 64)

    p = rand(MP)
    X = randn(T, size(p)...)

    expected = similar(X)
    for i in axes(p, 3)
        ManifoldsBase.project!(
            M, view(expected, :, :, i), view(p, :, :, i), view(X, :, :, i)
        )
    end

    X_cu = CuArray(X)
    Y_cu = similar(X_cu)
    ManifoldsBase.project!(MP, Y_cu, CuArray(p), X_cu)

    @test isapprox(Array(Y_cu), expected; atol = 2.0e-14, rtol = 2.0e-14)
    return nothing
end

# log on PowerManifold(M, 64), six random point pairs.
function _gum_check_log(M; seed, cast = identity, vector_atol, tol)
    Random.seed!(seed)
    MP = PowerManifold(M, 64)
    for _ in 1:6
        p = cast(rand(MP))
        q = cast(rand(MP))
        expected = log(MP, p, q)

        actual = Array(log(MP, CuArray(p), CuArray(q)))

        @test is_vector(MP, p, actual; atol = vector_atol)
        @test isapprox(actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

@testset "GeneralUnitaryMatrices CUDA" begin

    # --- exp! ---
    @testset "Rotations exp! batched" begin
        _gum_check_exp(Rotations(8); seed = 42, scale = 0.25, tol = 2.0e-14)
    end

    @testset "Rotations exp! Float32" begin
        _gum_check_exp(
            Rotations(8);
            seed = 43, cast = _gum_to_f32, scale = Float32(0.25), tol = 2.0f-5,
        )
    end

    @testset "UnitaryMatrices exp! batched" begin
        _gum_check_exp(UnitaryMatrices(8); seed = 44, scale = 0.25, tol = 2.0e-14)
    end

    # --- fused polar retraction ---
    @testset "Rotations retract_polar_fused batched" begin
        _gum_check_polar(Rotations(8), 64; seed = 46, t = 0.3, tol = 2.0e-14)
    end

    @testset "Rotations retract_polar_fused Float32" begin
        _gum_check_polar(
            Rotations(8), 64;
            seed = 47, cast = _gum_to_f32, t = Float32(0.3), tol = 2.0f-5,
        )
    end

    @testset "UnitaryMatrices retract_polar_fused batched" begin
        # NOTE(review): Float64 data compared with a Float32-sized tolerance here.
        _gum_check_polar(
            UnitaryMatrices(8), 64; seed = 48, t = 0.3, scale = 0.25, tol = 2.0f-5,
        )
    end

    # --- project! (tangent vectors) ---
    @testset "Rotations project! tangent" begin
        _gum_check_project_tangent(Rotations(8), Float64; seed = 50)
    end

    @testset "UnitaryMatrices project! tangent" begin
        _gum_check_project_tangent(UnitaryMatrices(8), ComplexF64; seed = 51)
    end

    # --- project! (points): reference is U * Vt from a per-slice CPU SVD ---
    @testset "OrthogonalMatrices project! point" begin
        Random.seed!(52)

        M = OrthogonalMatrices(8)
        MP = PowerManifold(M, 64)

        p = rand(MP)
        p_noisy = p .+ 0.01 .* randn(size(p)...)

        expected = similar(p)
        for i in axes(p, 3)
            s = svd(p_noisy[:, :, i])
            expected[:, :, i] .= s.U * s.Vt
        end

        p_noisy_cu = CuArray(p_noisy)
        q_cu = similar(p_noisy_cu)
        ManifoldsBase.project!(MP, q_cu, p_noisy_cu)
        actual = Array(q_cu)

        @test is_point(MP, actual)
        @test isapprox(actual, expected; atol = 2.0e-14, rtol = 2.0e-14)
    end

    # --- log! ---
    @testset "Rotations log! batched" begin
        _gum_check_log(Rotations(8); seed = 60, vector_atol = 2.0e-6, tol = 2.0e-6)
    end

    @testset "Rotations log! Float32" begin
        _gum_check_log(
            Rotations(8);
            seed = 61, cast = _gum_to_f32, vector_atol = 1.0f-3, tol = 2.0f-3,
        )
    end

    @testset "UnitaryMatrices log! batched" begin
        _gum_check_log(UnitaryMatrices(8); seed = 62, vector_atol = 1.0e-4, tol = 1.0e-4)
    end

    # --- large matrices (48 x 48), intended to exercise the SVD fallback branch ---
    @testset "retract_polar_fused fallback stress" begin
        _gum_check_polar(
            Rotations(48), 8;
            seed = 53, t = 0.2, scale = 0.2, tol = 2.0e-12, trials = 4,
        )
    end
end
# Helpers for the CUDA Stiefel testsets: run the operation on CPU arrays and on
# CuArrays and compare the downloaded GPU result with the CPU reference.
# `cast` converts the sampled Float64 data to the element type under test.

_stiefel_cu_to_f32(A) = Float32.(A)

# exp on PowerManifold(M, 64), six random draws.
function _stiefel_cu_check_exp(M; seed, cast = identity, scale, tol)
    Random.seed!(seed)
    MP = PowerManifold(M, 64)
    for _ in 1:6
        p = cast(rand(MP))
        X = scale .* cast(rand(MP; vector_at = p))
        expected = exp(MP, p, X)

        actual = Array(exp(MP, CuArray(p), CuArray(X)))

        @test is_point(MP, actual)
        @test isapprox(MP, p, actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

# Fused polar retraction on PowerManifold(M, batch).
function _stiefel_cu_check_polar(
        M, batch; seed, cast = identity, t, scale = 1, tol, trials = 1,
    )
    Random.seed!(seed)
    MP = PowerManifold(M, batch)
    for _ in 1:trials
        p = cast(rand(MP))
        X = scale .* cast(rand(MP; vector_at = p))

        expected = similar(p)
        ManifoldsBase.retract_fused!(MP, expected, p, X, t, PolarRetraction())

        p_cu = CuArray(p)
        q_cu = similar(p_cu)
        ManifoldsBase.retract_fused!(MP, q_cu, p_cu, CuArray(X), t, PolarRetraction())
        actual = Array(q_cu)

        @test is_point(MP, actual)
        @test isapprox(MP, p, actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

@testset "Stiefel CUDA" begin
    # Small unseeded smoke test with a loose tolerance.
    @testset "exp! basic" begin
        MP = PowerManifold(Stiefel(4, 2), 5)

        p = rand(MP)
        X = rand(MP; vector_at = p)
        expected = exp(MP, p, X)

        Y_cu = exp(MP, CuArray(p), CuArray(X))
        @test isapprox(MP, p, Array(Y_cu), expected; atol = 1.0e-10)
    end

    @testset "exp! batched stress" begin
        _stiefel_cu_check_exp(Stiefel(8, 4); seed = 42, scale = 0.25, tol = 2.0e-14)
    end

    @testset "exp! batched stress Float32" begin
        _stiefel_cu_check_exp(
            Stiefel(8, 4);
            seed = 43, cast = _stiefel_cu_to_f32, scale = Float32(0.25), tol = 2.0f-5,
        )
    end

    @testset "retract_polar_fused batched" begin
        _stiefel_cu_check_polar(Stiefel(8, 4), 64; seed = 46, t = 0.3, tol = 2.0e-14)
    end

    @testset "retract_polar_fused batched Float32" begin
        _stiefel_cu_check_polar(
            Stiefel(8, 4), 64;
            seed = 47, cast = _stiefel_cu_to_f32, t = Float32(0.3), tol = 2.0f-5,
        )
    end

    # Larger matrices (48 x 16), intended to exercise the fallback branch.
    @testset "retract_polar_fused fallback stress" begin
        _stiefel_cu_check_polar(
            Stiefel(48, 16), 8;
            seed = 48, t = 0.2, scale = 0.2, tol = 2.0e-12, trials = 4,
        )
    end
end
using GPUArrays

# JLArray tests for plain (non-PowerManifold) manifolds.
# They check that the generic Manifolds.jl code paths run on a GPU-style array type
# with scalar indexing disallowed. The CuArray-specific PowerManifold methods are
# covered by the tests under test/cuda/.

# Polar retraction on `M`: three random draws, JLArray result versus CPU reference.
# `cast` converts the sampled Float64 data to the element type under test.
function _gum_jl_check_polar_retract(M; seed, cast = identity, scale, tol)
    Random.seed!(seed)
    for _ in 1:3
        p = cast(rand(M))
        X = scale .* cast(rand(M; vector_at = p))
        expected = retract(M, p, X, PolarRetraction())

        q_jl = retract(M, JLArray(p), JLArray(X), PolarRetraction())
        actual = Array(q_jl)

        @test is_point(M, actual)
        @test isapprox(actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

@testset "GeneralUnitaryMatrices JLArray" begin
    GPUArrays.allowscalar(false)

    @testset "Rotations retract Polar Float64" begin
        _gum_jl_check_polar_retract(Rotations(4); seed = 42, scale = 0.25, tol = 2.0e-14)
    end

    @testset "Rotations retract Polar Float32" begin
        _gum_jl_check_polar_retract(
            Rotations(4);
            seed = 43, cast = A -> Float32.(A), scale = Float32(0.25), tol = 2.0f-5,
        )
    end

    @testset "Grassmann retract Polar Float64" begin
        _gum_jl_check_polar_retract(
            Grassmann(6, 3); seed = 44, scale = 0.25, tol = 2.0e-14,
        )
    end
end
using GPUArrays

# JLArray tests for plain (non-PowerManifold) Stiefel.
# They check that the generic Manifolds.jl code paths run on a GPU-style array type
# with scalar indexing disallowed. The CuArray-specific PowerManifold methods are
# covered by the tests under test/cuda/.

# Polar retraction on `M`: three random draws, JLArray result versus CPU reference.
# `cast` converts the sampled Float64 data to the element type under test.
function _stiefel_jl_check_polar_retract(M; seed, cast = identity, scale, tol)
    Random.seed!(seed)
    for _ in 1:3
        p = cast(rand(M))
        X = scale .* cast(rand(M; vector_at = p))
        expected = retract(M, p, X, PolarRetraction())

        q_jl = retract(M, JLArray(p), JLArray(X), PolarRetraction())
        actual = Array(q_jl)

        @test is_point(M, actual)
        @test isapprox(actual, expected; atol = tol, rtol = tol)
    end
    return nothing
end

@testset "Stiefel JLArray" begin
    GPUArrays.allowscalar(false)

    @testset "Stiefel retract Polar Float64" begin
        _stiefel_jl_check_polar_retract(
            Stiefel(8, 4); seed = 42, scale = 0.25, tol = 2.0e-14,
        )
    end

    @testset "Stiefel retract Polar Float32" begin
        _stiefel_jl_check_polar_retract(
            Stiefel(8, 4);
            seed = 43, cast = A -> Float32.(A), scale = Float32(0.25), tol = 2.0f-5,
        )
    end

    # Project an unconstrained random 8 x 4 matrix onto the tangent space at p.
    @testset "Stiefel project tangent Float64" begin
        Random.seed!(44)

        M = Stiefel(8, 4)

        for _ in 1:3
            p = rand(M)
            X = randn(8, 4)
            expected = similar(X)
            project!(M, expected, p, X)

            X_jl = JLArray(X)
            Y_jl = similar(X_jl)
            project!(M, Y_jl, JLArray(p), X_jl)
            actual = Array(Y_jl)

            @test is_vector(M, p, actual; atol = 2.0e-14)
            @test isapprox(actual, expected; atol = 2.0e-14, rtol = 2.0e-14)
        end
    end
end
using ManifoldsGPU
using Test
using Random
using LinearAlgebra
using ManifoldsBase, Manifolds
using CUDA
using GPUArrays

# Test entry point. The suites live in one file per manifold under test/jlarray
# (run everywhere) and test/cuda (run only when CUDA is functional).
@testset "ManifoldsGPU.jl" begin
    # JLArray tests (CI-safe, no GPU hardware required)
    @testset "JLArray" begin
        using JLArrays
        for file in ("test_stiefel.jl", "test_general_unitary_matrices.jl")
            include(joinpath(@__DIR__, "jlarray", file))
        end
    end

    # CUDA tests (requires GPU hardware)
    if !CUDA.functional()
        @info "CUDA not available, skipping CUDA tests"
    else
        @testset "CUDA" begin
            for file in ("test_stiefel.jl", "test_general_unitary_matrices.jl")
                include(joinpath(@__DIR__, "cuda", file))
            end
        end
    end
end