Skip to content

Commit 85a7f30

Browse files
committed
add GPU retract_qr_fused! via Cholesky-QR for GeneralUnitaryMatrices, Stiefel, Grassmann
1 parent dfabf06 commit 85a7f30

9 files changed

Lines changed: 262 additions & 9 deletions

File tree

benchmarks/main.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
include(joinpath(@__DIR__, "utils.jl"))
99

10-
function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float32, t::Float32, samples::Int, seed::Int, point_type, methods)
10+
function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float32, t::Float32, samples::Int, seed::Int, point_type, methods, error_fn = nothing)
1111
data = _setup_data(M; batch = batch, scale = scale, seed = seed, point_type = point_type, use_power_manifold = true)
1212
manifold_label = "PowerManifold($name, $batch)"
1313
results = NamedTuple[]
@@ -17,7 +17,7 @@ function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float3
1717
println()
1818

1919
for method in methods
20-
push!(results, _benchmark_retraction(method; MP = data.MB, p_cpu = data.p_cpu, X_cpu = data.X_cpu, p_gpu = data.p_gpu, X_gpu = data.X_gpu, t = t, samples = samples, manifold_label = manifold_label))
20+
push!(results, _benchmark_retraction(method; MP = data.MB, p_cpu = data.p_cpu, X_cpu = data.X_cpu, p_gpu = data.p_gpu, X_gpu = data.X_gpu, t = t, samples = samples, manifold_label = manifold_label, error_fn = error_fn))
2121
println()
2222
end
2323

@@ -44,15 +44,15 @@ function main()
4444

4545
append!(all_results, benchmark_manifold("Rotations($n)", Rotations(n); batch = batch, scale = scale, samples = samples, seed = seed + 2, point_type = Float32))
4646

47-
append!(all_results, _benchmark_extra_retractions("Rotations($n)", Rotations(n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2, point_type = Float32, methods = [PolarRetraction()]))
47+
append!(all_results, _benchmark_extra_retractions("Rotations($n)", Rotations(n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2, point_type = Float32, methods = [PolarRetraction(), QRRetraction()]))
4848

4949
append!(all_results, benchmark_manifold("UnitaryMatrices($n)", UnitaryMatrices(n); batch = batch, scale = scale, samples = samples, seed = seed + 3, point_type = ComplexF32))
5050

5151
append!(all_results, benchmark_manifold("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, samples = samples, seed = seed + 4, point_type = Float32, exp_error_fn = _subspace_error))
5252

53-
append!(all_results, _benchmark_extra_retractions("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4, point_type = Float32, methods = [PolarRetraction()]))
53+
append!(all_results, _benchmark_extra_retractions("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4, point_type = Float32, methods = [PolarRetraction(), QRRetraction()], error_fn = _subspace_error))
5454

55-
append!(all_results, _benchmark_extra_retractions("Stiefel($n, $k)", Stiefel(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5, point_type = Float32, methods = [ExponentialRetraction(), PolarRetraction()]))
55+
append!(all_results, _benchmark_extra_retractions("Stiefel($n, $k)", Stiefel(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5, point_type = Float32, methods = [ExponentialRetraction(), PolarRetraction(), QRRetraction()]))
5656

5757
markdown_table = generate_markdown_summary_table(all_results)
5858
println("=== Markdown summary table ===")

benchmarks/utils.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ function _benchmark_retraction(
319319
t::Float32,
320320
samples::Int,
321321
manifold_label::String,
322+
error_fn = nothing,
322323
)
323324
q_cpu = similar(p_cpu)
324325
q_gpu = similar(p_gpu)
@@ -333,7 +334,8 @@ function _benchmark_retraction(
333334

334335
cpu_res = exp(MP, p_cpu, X_cpu)
335336
gpu_res = Array(CUDA.@sync exp(MP, p_gpu, X_gpu))
336-
relerr = _relative_error(cpu_res, gpu_res)
337+
relerr = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
338+
relerr_label = isnothing(error_fn) ? "||Ycpu - Ygpu||/||Ycpu||" : "distance(Ycpu, Ygpu)"
337339

338340
_print_results(
339341
name = method_name,
@@ -344,7 +346,7 @@ function _benchmark_retraction(
344346
cpu_ms = cpu_ms,
345347
gpu_ms = gpu_ms,
346348
relerr = relerr,
347-
relerr_label = "||Ycpu - Ygpu||/||Ycpu||",
349+
relerr_label = relerr_label,
348350
extra_lines = ["Retraction method: $method_name"],
349351
)
350352

@@ -366,7 +368,8 @@ function _benchmark_retraction(
366368

367369
cpu_res = ManifoldsBase.retract_fused!(MP, q_cpu, p_cpu, X_cpu, t, method)
368370
gpu_res = Array(CUDA.@sync ManifoldsBase.retract_fused!(MP, q_gpu, p_gpu, X_gpu, t, method))
369-
relerr = _relative_error(cpu_res, gpu_res)
371+
relerr = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
372+
relerr_label = isnothing(error_fn) ? "||Qcpu - Qgpu||/||Qcpu||" : "distance(Qcpu, Qgpu)"
370373

371374
_print_results(
372375
name = method_name,
@@ -377,7 +380,7 @@ function _benchmark_retraction(
377380
cpu_ms = cpu_ms,
378381
gpu_ms = gpu_ms,
379382
relerr = relerr,
380-
relerr_label = "||Qcpu - Qgpu||/||Qcpu||",
383+
relerr_label = relerr_label,
381384
extra_lines = ["Retraction scalar t: $t", "Retraction method: $method_name"],
382385
)
383386

ext/ManifoldsGPUCUDAExt/GeneralUnitaryMatrices.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,29 @@ function ManifoldsBase.retract_fused!(
5353
return ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
5454
end
5555

56+
# Batched QR retraction for a power manifold of general unitary matrices on the GPU.
#
# For every batch slice i, q[:, :, i] is overwritten with the orthonormal factor
# (as computed by `_cholesky_qr_gpu!`) of
#     p[:, :, i] + t * p[:, :, i] * X[:, :, i].
#
# Arguments
#   q : output buffer; may be the very same array as `p` (in-place retraction)
#   p : base points, one matrix per slice along the third dimension
#   X : tangent vectors, stored so that `p * X` is the ambient direction
#   t : step size, converted to the element type `T`
#
# Returns `q`.
function ManifoldsBase.retract_qr_fused!(
        ::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
        q::CuArray{T, 3},
        p::CuArray{T, 3},
        X::CuArray{T, 3},
        t::Number,
    ) where {T <: Number}
    # Form p * X in a fresh buffer. Accumulating straight into `q` with the mutating
    # GEMM is only valid when `q` shares no memory with `p` or `X`; an in-place call
    # (q === p) would make the GEMM output overlap one of its inputs.
    pX = CUDA.CUBLAS.gemm_strided_batched('N', 'N', p, X)
    # The elementwise update reads and writes matching indices, so it is safe even
    # when `q` and `p` are the same array.
    q .= p .+ T(t) .* pX
    return _cholesky_qr_gpu!(q)
end
67+
68+
# `retract_fused!` entry point for `QRRetraction` on a power manifold of general
# unitary matrices with `CuArray` storage: forwards to `retract_qr_fused!`.
function ManifoldsBase.retract_fused!(
69+
M::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
70+
q::CuArray{T, 3},
71+
p::CuArray{T, 3},
72+
X::CuArray{T, 3},
73+
t::Number,
74+
# The retraction-method argument is unnamed: it only selects this method.
::QRRetraction,
75+
) where {T <: Number}
76+
# Returns whatever `retract_qr_fused!` returns for the same arguments.
return ManifoldsBase.retract_qr_fused!(M, q, p, X, t)
77+
end
78+
5679
function ManifoldsBase.project!(
5780
::PowerManifold{
5881
<:Any,

ext/ManifoldsGPUCUDAExt/Grassmann.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,28 @@ function ManifoldsBase.retract_fused!(
8383
return ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
8484
end
8585

86+
# Batched QR retraction for a real Grassmann power manifold with `CuArray` storage.
# Writes `p + t * X` into `q` and then orthogonalizes `q` in place with
# `_cholesky_qr_gpu!`, whose return value is returned.
function ManifoldsBase.retract_qr_fused!(
87+
::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
88+
q::CuArray{T, 3},
89+
p::CuArray{T, 3},
90+
X::CuArray{T, 3},
91+
t::Number,
92+
) where {T <: Real}
93+
# Fused elementwise update into the output buffer.
# NOTE(review): `t` is used as passed (no conversion to `T`) — confirm that mixed
# precision, e.g. a Float64 `t` with Float32 arrays, is intended.
q .= p .+ t .* X
94+
return _cholesky_qr_gpu!(q)
95+
end
96+
97+
# `retract_fused!` entry point for `QRRetraction` on a real Grassmann power manifold
# with `CuArray` storage: forwards to `retract_qr_fused!`.
function ManifoldsBase.retract_fused!(
98+
M::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
99+
q::CuArray{T, 3},
100+
p::CuArray{T, 3},
101+
X::CuArray{T, 3},
102+
t::Number,
103+
# The retraction-method argument is unnamed: it only selects this method.
::QRRetraction,
104+
) where {T <: Real}
105+
# Returns whatever `retract_qr_fused!` returns for the same arguments.
return ManifoldsBase.retract_qr_fused!(M, q, p, X, t)
106+
end
107+
86108
function ManifoldsBase.inverse_retract_polar!(
87109
::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
88110
X::CuArray{T, 3},

ext/ManifoldsGPUCUDAExt/Stiefel.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,25 @@ function ManifoldsBase.retract_fused!(
6969
) where {T <: Real}
7070
return ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
7171
end
72+
73+
# Batched QR retraction for a real Stiefel power manifold with `CuArray` storage.
# Writes `p + t * X` into `q` and then orthogonalizes `q` in place with
# `_cholesky_qr_gpu!`, whose return value is returned.
function ManifoldsBase.retract_qr_fused!(
74+
::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
75+
q::CuArray{T, 3},
76+
p::CuArray{T, 3},
77+
X::CuArray{T, 3},
78+
t::Number,
79+
) where {T <: Real}
80+
# Fused elementwise update into the output buffer.
# NOTE(review): `t` is used as passed (no conversion to `T`) — confirm that mixed
# precision, e.g. a Float64 `t` with Float32 arrays, is intended.
q .= p .+ t .* X
81+
return _cholesky_qr_gpu!(q)
82+
end
83+
84+
# `retract_fused!` entry point for `QRRetraction` on a real Stiefel power manifold
# with `CuArray` storage: forwards to `retract_qr_fused!`.
function ManifoldsBase.retract_fused!(
85+
M::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
86+
q::CuArray{T, 3},
87+
p::CuArray{T, 3},
88+
X::CuArray{T, 3},
89+
t::Number,
90+
# The retraction-method argument is unnamed: it only selects this method.
::QRRetraction,
91+
) where {T <: Real}
92+
# Returns whatever `retract_qr_fused!` returns for the same arguments.
return ManifoldsBase.retract_qr_fused!(M, q, p, X, t)
93+
end

ext/ManifoldsGPUCUDAExt/helpers.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,22 @@ function _polar_project_gpu!(q::CuArray{T, 3}) where {T}
185185
end
186186
return q
187187
end
188+
189+
# In-place Cholesky-QR orthogonalization: Q = A * R⁻¹ where A'A = R'R.
190+
# Cholesky factor R has positive diagonal, matching CPU's sign-corrected Householder QR.
#
# `A` is a 3-D `CuArray`; each slice `A[:, :, i]` is treated as an independent matrix
# and overwritten with its Q factor. Returns `A`.
#
# NOTE(review): working through the Gram matrix A'A squares the condition number of
# each slice, so accuracy degrades for ill-conditioned inputs — confirm the inputs
# produced by the retractions stay well conditioned.
191+
function _cholesky_qr_gpu!(A::CuArray{T, 3}) where {T}
192+
# Number of matrices in the batch (third dimension).
batch = size(A, 3)
193+
194+
# Gram matrix G = A'A (k×k×batch where A is n×k×batch)
# 'C' requests the conjugate transpose of the first operand; this call allocates G.
195+
G = CUDA.CUBLAS.gemm_strided_batched('C', 'N', A, A)
196+
197+
# Cholesky factorization: upper triangle of G becomes R where G = R'R
# The batched routines below take one matrix per batch entry, hence the views.
198+
G_slices = [view(G, :, :, i) for i in 1:batch]
# NOTE(review): the value returned by `potrfBatched!` is discarded here; if it carries
# per-slice status codes, a failed factorization (non-positive-definite Gram matrix)
# would go unnoticed — confirm against the CUDA.jl wrapper.
199+
CUDA.CUSOLVER.potrfBatched!('U', G_slices)
200+
201+
# Right triangular solve: Q * R = A → Q = A * R⁻¹ (overwrites A with Q)
202+
A_slices = [view(A, :, :, i) for i in 1:batch]
# Flags: 'R' = triangular factor on the right, 'U' = upper triangle,
# 'N' = no transpose, 'N' = non-unit diagonal.
203+
CUDA.CUBLAS.trsm_batched!('R', 'U', 'N', 'N', one(T), G_slices, A_slices)
204+
205+
return A
206+
end

test/cuda/test_general_unitary_matrices.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,4 +402,73 @@
402402
@test isapprox(i_gpu, i_cpu; atol = 2.0f-4, rtol = 2.0f-4)
403403
@test isapprox(n_gpu, n_cpu; atol = 2.0f-4, rtol = 2.0f-4)
404404
end
405+
406+
# Checks that the `CuArray` method of `retract_fused!` with `QRRetraction` on a power
# manifold of `Rotations(8)` agrees with the result obtained on the original arrays.
@testset "Rotations retract_qr_fused Float64" begin
407+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(54)
408+
409+
M = Rotations(8)
410+
MP = PowerManifold(M, 64)
411+
t = 0.3
412+
413+
p = rand(MP)
414+
X = rand(MP; vector_at = p)
415+
416+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
417+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
418+
419+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
420+
X_cu = CuArray(X)
421+
q_cu = similar(p_cu)
422+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
423+
q_cu_h = Array(q_cu)
424+
425+
# The CuArray result must be a valid point and match the reference entrywise.
@test is_point(MP, q_cu_h)
426+
@test isapprox(q_cu_h, q; atol = 2.0e-14, rtol = 2.0e-14)
427+
end
428+
429+
# Float32 variant: checks that the `CuArray` method of `retract_fused!` with
# `QRRetraction` on a power manifold of `Rotations(8)` agrees with the result obtained
# on the original arrays.
@testset "Rotations retract_qr_fused Float32" begin
430+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(55)
431+
432+
M = Rotations(8)
433+
MP = PowerManifold(M, 64)
434+
t = Float32(0.3)
435+
436+
# Random data is generated first and then converted elementwise to Float32.
p = Float32.(rand(MP))
437+
X = Float32.(rand(MP; vector_at = p))
438+
439+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
440+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
441+
442+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
443+
X_cu = CuArray(X)
444+
q_cu = similar(p_cu)
445+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
446+
q_cu_h = Array(q_cu)
447+
448+
# The CuArray result must be a valid point and match the reference entrywise,
# with single-precision tolerances.
@test is_point(MP, q_cu_h)
449+
@test isapprox(q_cu_h, q; atol = 2.0f-5, rtol = 2.0f-5)
450+
end
451+
452+
# Checks that the `CuArray` method of `retract_fused!` with `QRRetraction` on a power
# manifold of `UnitaryMatrices(8)` agrees with the result obtained on the original
# arrays.
@testset "UnitaryMatrices retract_qr_fused ComplexF64" begin
453+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(56)
454+
455+
M = UnitaryMatrices(8)
456+
MP = PowerManifold(M, 64)
457+
t = 0.3
458+
459+
p = rand(MP)
460+
# The random tangent vector is scaled by 0.25 before use.
# NOTE(review): presumably to keep the step small — confirm the intent.
X = 0.25 .* rand(MP; vector_at = p)
461+
462+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
463+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
464+
465+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
466+
X_cu = CuArray(X)
467+
q_cu = similar(p_cu)
468+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
469+
q_cu_h = Array(q_cu)
470+
471+
# The CuArray result must be a valid point and match the reference entrywise.
@test is_point(MP, q_cu_h)
472+
@test isapprox(q_cu_h, q; atol = 2.0e-14, rtol = 2.0e-14)
473+
end
405474
end

test/cuda/test_grassmann.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,55 @@
259259
@test isapprox(q_cu_h, q; atol = 2.0f-5, rtol = 2.0f-5)
260260
end
261261

262+
# GPU uses Cholesky-QR, CPU uses Householder QR — same subspace, different matrix.
263+
# Compare via distance (same pattern as exp! tests above).
264+
265+
# Checks the `CuArray` method of `retract_fused!` with `QRRetraction` on a power
# manifold of `Grassmann(8, 4)` against the result obtained on the original arrays.
# The two results are compared through the manifold `distance`, not entrywise.
@testset "retract_qr_fused Float64" begin
266+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(89)
267+
268+
M = Grassmann(8, 4)
269+
MP = PowerManifold(M, 32)
270+
t = 0.3
271+
272+
p = rand(MP)
273+
X = rand(MP; vector_at = p)
274+
275+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
276+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
277+
278+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
279+
X_cu = CuArray(X)
280+
q_cu = similar(p_cu)
281+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
282+
q_cu_h = Array(q_cu)
283+
284+
# Valid point, and within a small manifold distance of the reference.
@test is_point(MP, q_cu_h)
285+
@test distance(MP, q_cu_h, q) < 2.0e-14
286+
end
287+
288+
# Float32 variant: checks the `CuArray` method of `retract_fused!` with `QRRetraction`
# on a power manifold of `Grassmann(8, 4)` against the result obtained on the original
# arrays. The two results are compared through the manifold `distance`, not entrywise.
@testset "retract_qr_fused Float32" begin
289+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(891)
290+
291+
M = Grassmann(8, 4)
292+
MP = PowerManifold(M, 32)
293+
t = Float32(0.3)
294+
295+
# Random data is generated first and then converted elementwise to Float32.
p = Float32.(rand(MP))
296+
X = Float32.(rand(MP; vector_at = p))
297+
298+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
299+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
300+
301+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
302+
X_cu = CuArray(X)
303+
q_cu = similar(p_cu)
304+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
305+
q_cu_h = Array(q_cu)
306+
307+
# Valid point, and within a single-precision manifold distance of the reference.
@test is_point(MP, q_cu_h)
308+
@test distance(MP, q_cu_h, q) < 2.0f-5
309+
end
310+
262311
@testset "inverse_retract_polar Float64" begin
263312
Random.seed!(90)
264313

test/cuda/test_stiefel.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,50 @@
173173
@test isapprox(i_gpu, i_cpu; atol = 1.0f-4, rtol = 1.0f-4)
174174
@test isapprox(n_gpu, n_cpu; atol = 1.0f-4, rtol = 1.0f-4)
175175
end
176+
177+
# Checks that the `CuArray` method of `retract_fused!` with `QRRetraction` on a power
# manifold of `Stiefel(8, 4)` agrees with the result obtained on the original arrays.
@testset "retract_qr_fused Float64" begin
178+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(57)
179+
180+
M = Stiefel(8, 4)
181+
MP = PowerManifold(M, 64)
182+
t = 0.3
183+
184+
p = rand(MP)
185+
X = rand(MP; vector_at = p)
186+
187+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
188+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
189+
190+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
191+
X_cu = CuArray(X)
192+
q_cu = similar(p_cu)
193+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
194+
q_cu_h = Array(q_cu)
195+
196+
# The CuArray result must be a valid point and match the reference entrywise.
@test is_point(MP, q_cu_h)
197+
@test isapprox(q_cu_h, q; atol = 2.0e-14, rtol = 2.0e-14)
198+
end
199+
200+
# Float32 variant: checks that the `CuArray` method of `retract_fused!` with
# `QRRetraction` on a power manifold of `Stiefel(8, 4)` agrees with the result obtained
# on the original arrays.
@testset "retract_qr_fused Float32" begin
201+
# Fixed seed so the random point and tangent vector are reproducible.
Random.seed!(58)
202+
203+
M = Stiefel(8, 4)
204+
MP = PowerManifold(M, 64)
205+
t = Float32(0.3)
206+
207+
# Random data is generated first and then converted elementwise to Float32.
p = Float32.(rand(MP))
208+
X = Float32.(rand(MP; vector_at = p))
209+
210+
# Reference result computed on the non-CuArray inputs.
q = similar(p)
211+
ManifoldsBase.retract_fused!(MP, q, p, X, t, QRRetraction())
212+
213+
# Same computation on CuArray copies; the result is copied back with `Array`.
p_cu = CuArray(p)
214+
X_cu = CuArray(X)
215+
q_cu = similar(p_cu)
216+
ManifoldsBase.retract_fused!(MP, q_cu, p_cu, X_cu, t, QRRetraction())
217+
q_cu_h = Array(q_cu)
218+
219+
# The CuArray result must be a valid point and match the reference entrywise,
# with single-precision tolerances.
@test is_point(MP, q_cu_h)
220+
@test isapprox(q_cu_h, q; atol = 2.0f-5, rtol = 2.0f-5)
221+
end
176222
end

0 commit comments

Comments
 (0)