Skip to content

Commit 966303d

Browse files
committed
try QR retraction on Stiefel
1 parent 15da7e3 commit 966303d

5 files changed

Lines changed: 225 additions & 18 deletions

File tree

.github/workflows/format.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: Runic formatting
2+
on:
3+
push:
4+
branches:
5+
- 'main'
6+
- 'release-'
7+
tags:
8+
- '*'
9+
pull_request:
10+
jobs:
11+
runic:
12+
name: Runic
13+
runs-on: ubuntu-latest
14+
# Permissions needed for reviewdog/action-suggester to post comments
15+
permissions:
16+
contents: read
17+
checks: write
18+
issues: write
19+
pull-requests: write
20+
steps:
21+
- uses: actions/checkout@v6
22+
# - uses: julia-actions/setup-julia@v2
23+
# with:
24+
# version: '1'
25+
# - uses: julia-actions/cache@v2
26+
- uses: fredrikekre/runic-action@v1
27+
with:
28+
version: '1'
29+
format_files: true
30+
# Fail on next step instead
31+
continue-on-error: ${{ github.event_name == 'pull_request' }}
32+
- uses: reviewdog/action-suggester@v1
33+
if: github.event_name == 'pull_request'
34+
with:
35+
tool_name: Runic
36+
fail_level: warning

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
# ManifoldsGPU
22

33
General GPU/CUDA support for the JuliaManifolds ecosystem.
4+
5+
The package is in early stages of development, and the API is not yet stable.
6+
7+
Notes:
8+
9+
- `exp!` on `PowerManifold(Stiefel(32, 16), 2048)` is about 20x faster on CUDA.
10+
- QR decomposition doesn't seem to be particularly fast on GPU. Q matrix formation can't even be batched as of Feburary 2026.

benchmarks/main.jl

Lines changed: 97 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Statistics
33

44
using ManifoldsGPU
55
using Manifolds
6+
using ManifoldsBase
67
using CUDA
78

89
function _time_median(f; samples::Int = 6)
@@ -16,7 +17,7 @@ function _time_median(f; samples::Int = 6)
1617
return median(timings), timings
1718
end
1819

19-
function benchmark_stiefel_exp(; n::Int = 32, k::Int = 16, batch::Int = 2048, scale::Float32 = 0.2f0, samples::Int = 6, seed::Int = 1234)
20+
function _setup_stiefel_data(; n::Int, k::Int, batch::Int, scale::Float32, seed::Int)
2021
Random.seed!(seed)
2122

2223
M = Stiefel(n, k)
@@ -28,33 +29,109 @@ function benchmark_stiefel_exp(; n::Int = 32, k::Int = 16, batch::Int = 2048, sc
2829
p_gpu = CuArray(p_cpu)
2930
X_gpu = CuArray(X_cpu)
3031

31-
exp(MP, p_cpu, X_cpu)
32-
CUDA.@sync exp(MP, p_gpu, X_gpu)
32+
return (; MP, p_cpu, X_cpu, p_gpu, X_gpu)
33+
end
3334

34-
cpu_ms, cpu_all = _time_median(; samples = samples) do
35-
exp(MP, p_cpu, X_cpu)
36-
end
35+
function _benchmark_cpu_gpu(cpu_f, gpu_f; samples::Int)
36+
cpu_f()
37+
gpu_f()
3738

38-
gpu_ms, gpu_all = _time_median(; samples = samples) do
39-
CUDA.@sync exp(MP, p_gpu, X_gpu)
40-
end
39+
cpu_ms, cpu_all = _time_median(cpu_f; samples = samples)
40+
gpu_ms, gpu_all = _time_median(gpu_f; samples = samples)
4141

42+
return cpu_ms, cpu_all, gpu_ms, gpu_all
43+
end
44+
45+
function _print_results(; name::String, n::Int, k::Int, batch::Int, samples::Int, cpu_all, gpu_all, cpu_ms::Float64, gpu_ms::Float64, relerr, relerr_label::String, extra_lines::Vector{String} = String[])
4246
speedup = cpu_ms / gpu_ms
43-
relerr = begin
44-
Y_cpu = exp(MP, p_cpu, X_cpu)
45-
Y_gpu = Array(CUDA.@sync exp(MP, p_gpu, X_gpu))
46-
norm(Y_cpu .- Y_gpu) / max(norm(Y_cpu), eps(Float32))
47-
end
4847

49-
println("=== ManifoldsGPU benchmark: exp on PowerManifold(Stiefel($n, $k), $batch) ===")
48+
println("=== ManifoldsGPU benchmark: $name on PowerManifold(Stiefel($n, $k), $batch) ===")
5049
println("Element type: Float32")
50+
for line in extra_lines
51+
println(line)
52+
end
5153
println("Samples: $samples")
5254
println("CPU times [ms]: ", round.(cpu_all; digits = 2))
5355
println("GPU times [ms]: ", round.(gpu_all; digits = 2))
5456
println("Median CPU [ms]: ", round(cpu_ms; digits = 2))
5557
println("Median GPU [ms]: ", round(gpu_ms; digits = 2))
5658
println("Speedup (CPU/GPU): ", round(speedup; digits = 2), "x")
57-
return println("Relative error ||Ycpu - Ygpu||/||Ycpu||: ", relerr)
59+
return println("Relative error $relerr_label: ", relerr)
60+
end
61+
62+
function benchmark_stiefel_exp(; n::Int = 32, k::Int = 16, batch::Int = 2048, scale::Float32 = 0.2f0, samples::Int = 6, seed::Int = 1234)
63+
data = _setup_stiefel_data(; n = n, k = k, batch = batch, scale = scale, seed = seed)
64+
MP = data.MP
65+
p_cpu = data.p_cpu
66+
X_cpu = data.X_cpu
67+
p_gpu = data.p_gpu
68+
X_gpu = data.X_gpu
69+
70+
cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu(
71+
() -> exp(MP, p_cpu, X_cpu),
72+
() -> CUDA.@sync exp(MP, p_gpu, X_gpu);
73+
samples = samples,
74+
)
75+
76+
relerr = begin
77+
Y_cpu = exp(MP, p_cpu, X_cpu)
78+
Y_gpu = Array(CUDA.@sync exp(MP, p_gpu, X_gpu))
79+
norm(Y_cpu .- Y_gpu) / max(norm(Y_cpu), eps(Float32))
80+
end
81+
82+
return _print_results(
83+
name = "exp",
84+
n = n,
85+
k = k,
86+
batch = batch,
87+
samples = samples,
88+
cpu_all = cpu_all,
89+
gpu_all = gpu_all,
90+
cpu_ms = cpu_ms,
91+
gpu_ms = gpu_ms,
92+
relerr = relerr,
93+
relerr_label = "||Ycpu - Ygpu||/||Ycpu||",
94+
)
95+
end
96+
97+
function benchmark_stiefel_retract_qr_fused(; n::Int = 32, k::Int = 16, batch::Int = 2048, scale::Float32 = 0.2f0, t::Float32 = 0.3f0, samples::Int = 6, seed::Int = 1234)
98+
data = _setup_stiefel_data(; n = n, k = k, batch = batch, scale = scale, seed = seed)
99+
MP = data.MP
100+
p_cpu = data.p_cpu
101+
X_cpu = data.X_cpu
102+
p_gpu = data.p_gpu
103+
X_gpu = data.X_gpu
104+
105+
q_cpu = similar(p_cpu)
106+
q_gpu = similar(p_gpu)
107+
108+
cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu(
109+
() -> ManifoldsBase.retract_fused!(MP, q_cpu, p_cpu, X_cpu, t, QRRetraction()),
110+
() -> CUDA.@sync ManifoldsBase.retract_fused!(MP, q_gpu, p_gpu, X_gpu, t, QRRetraction());
111+
samples = samples,
112+
)
113+
114+
relerr = begin
115+
ManifoldsBase.retract_fused!(MP, q_cpu, p_cpu, X_cpu, t, QRRetraction())
116+
CUDA.@sync ManifoldsBase.retract_fused!(MP, q_gpu, p_gpu, X_gpu, t, QRRetraction())
117+
q_gpu_h = Array(q_gpu)
118+
norm(q_cpu .- q_gpu_h) / max(norm(q_cpu), eps(Float32))
119+
end
120+
121+
return _print_results(
122+
name = "retract_qr_fused",
123+
n = n,
124+
k = k,
125+
batch = batch,
126+
samples = samples,
127+
cpu_all = cpu_all,
128+
gpu_all = gpu_all,
129+
cpu_ms = cpu_ms,
130+
gpu_ms = gpu_ms,
131+
relerr = relerr,
132+
relerr_label = "||Qcpu - Qgpu||/||Qcpu||",
133+
extra_lines = ["Retraction scalar t: $t"],
134+
)
58135
end
59136

60137
function _parse_arg(i::Int, default)
@@ -68,8 +145,10 @@ function main()
68145
samples = _parse_arg(4, 6)
69146

70147
println("Running with n=$n, k=$k, batch=$batch, samples=$samples")
71-
72-
return benchmark_stiefel_exp(; n = n, k = k, batch = batch, samples = samples)
148+
println()
149+
benchmark_stiefel_exp(; n = n, k = k, batch = batch, samples = samples)
150+
println()
151+
return benchmark_stiefel_retract_qr_fused(; n = n, k = k, batch = batch, samples = samples)
73152
end
74153

75154
main()

ext/ManifoldsGPUCUDAExt/Stiefel.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,42 @@ function ManifoldsBase.exp!(
3030

3131
return q
3232
end
33+
34+
function ManifoldsBase.retract_qr_fused!(
35+
M::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
36+
q::CuArray{T, 3},
37+
p::CuArray{T, 3},
38+
X::CuArray{T, 3},
39+
t::Number,
40+
) where {T <: Real}
41+
_, k = ManifoldsBase.get_parameter(M.manifold.size)
42+
batch = size(q, 3)
43+
44+
q .= p .+ t .* X
45+
46+
q_views = [@view q[:, :, i] for i in 1:batch]
47+
tau, q_factors = CUDA.CUBLAS.geqrf_batched!(q_views)
48+
49+
for i in 1:batch
50+
q_factor_cpu = Array(q_factors[i])
51+
tau_cpu = Array(tau[i])
52+
d = diag(@view(q_factor_cpu[1:k, 1:k]))
53+
s = sign.(sign.(d .+ T(1 // 2)))
54+
LinearAlgebra.LAPACK.orgqr!(q_factor_cpu, tau_cpu)
55+
q_factor_cpu .*= reshape(s, 1, k)
56+
copyto!(q_factors[i], q_factor_cpu)
57+
end
58+
59+
return q
60+
end
61+
62+
function ManifoldsBase.retract_fused!(
63+
M::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
64+
q::CuArray{T, 3},
65+
p::CuArray{T, 3},
66+
X::CuArray{T, 3},
67+
t::Number,
68+
::QRRetraction,
69+
) where {T <: Real}
70+
return ManifoldsBase.retract_qr_fused!(M, q, p, X, t)
71+
end

test/runtests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,50 @@ using CUDA
6363
@test isapprox(MP, p, Y_cu_h, Y; atol = 2.0f-5, rtol = 2.0f-5)
6464
end
6565
end
66+
67+
@testset "Stiefel retract_qr_fused batched" begin
68+
Random.seed!(44)
69+
70+
M = Stiefel(8, 4)
71+
MP = PowerManifold(M, 64)
72+
t = 0.3
73+
74+
p = rand(MP)
75+
X = rand(MP; vector_at = p)
76+
77+
q = similar(p)
78+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
79+
80+
p_cu = CuArray(p)
81+
X_cu = CuArray(X)
82+
q_cu = similar(p_cu)
83+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
84+
q_cu_h = Array(q_cu)
85+
86+
@test is_point(MP, q_cu_h)
87+
@test isapprox(MP, p, q_cu_h, q; atol = 2.0e-14, rtol = 2.0e-14)
88+
end
89+
90+
@testset "Stiefel retract_qr_fused batched Float32" begin
91+
Random.seed!(45)
92+
93+
M = Stiefel(8, 4)
94+
MP = PowerManifold(M, 64)
95+
t = Float32(0.3)
96+
97+
p = Float32.(rand(MP))
98+
X = Float32.(rand(MP; vector_at = p))
99+
100+
q = similar(p)
101+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
102+
103+
p_cu = CuArray(p)
104+
X_cu = CuArray(X)
105+
q_cu = similar(p_cu)
106+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
107+
q_cu_h = Array(q_cu)
108+
109+
@test is_point(MP, q_cu_h)
110+
@test isapprox(MP, p, q_cu_h, q; atol = 2.0f-5, rtol = 2.0f-5)
111+
end
66112
end

0 commit comments

Comments
 (0)