Skip to content

Commit 667dd3d

Browse files
authored
Merge pull request #11 from zazabap/feat/cholesky-qr-retraction-squashed
add GPU retract_qr_fused! via Cholesky-QR
2 parents d5943c5 + c26e511 commit 667dd3d

10 files changed

Lines changed: 309 additions & 48 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ManifoldsGPU"
22
uuid = "007d1224-8888-47ee-87d0-87e096ff9b5b"
33
version = "0.1.0-DEV"
4-
authors = ["Mateusz Baran <mateuszbaran89@gmail.com> and contributors"]
4+
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>", "Shiwen An <sweynan@icloud.com>", "and contributors"]
55

66
[deps]
77
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

benchmarks/main.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
include(joinpath(@__DIR__, "utils.jl"))
99

10-
function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float32, t::Float32, samples::Int, seed::Int, point_type, methods)
10+
function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float32, t::Float32, samples::Int, seed::Int, point_type, methods, error_fn = nothing)
1111
data = _setup_data(M; batch = batch, scale = scale, seed = seed, point_type = point_type, use_power_manifold = true)
1212
manifold_label = "PowerManifold($name, $batch)"
1313
results = NamedTuple[]
@@ -17,7 +17,7 @@ function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float3
1717
println()
1818

1919
for method in methods
20-
push!(results, _benchmark_retraction(method; MP = data.MB, p_cpu = data.p_cpu, X_cpu = data.X_cpu, p_gpu = data.p_gpu, X_gpu = data.X_gpu, t = t, samples = samples, manifold_label = manifold_label))
20+
push!(results, _benchmark_retraction(method; MP = data.MB, p_cpu = data.p_cpu, X_cpu = data.X_cpu, p_gpu = data.p_gpu, X_gpu = data.X_gpu, t = t, samples = samples, manifold_label = manifold_label, error_fn = error_fn))
2121
println()
2222
end
2323

@@ -44,15 +44,15 @@ function main()
4444

4545
append!(all_results, benchmark_manifold("Rotations($n)", Rotations(n); batch = batch, scale = scale, samples = samples, seed = seed + 2, point_type = Float32))
4646

47-
append!(all_results, _benchmark_extra_retractions("Rotations($n)", Rotations(n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2, point_type = Float32, methods = [PolarRetraction()]))
47+
append!(all_results, _benchmark_extra_retractions("Rotations($n)", Rotations(n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2, point_type = Float32, methods = [PolarRetraction(), QRRetraction()]))
4848

4949
append!(all_results, benchmark_manifold("UnitaryMatrices($n)", UnitaryMatrices(n); batch = batch, scale = scale, samples = samples, seed = seed + 3, point_type = ComplexF32))
5050

5151
append!(all_results, benchmark_manifold("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, samples = samples, seed = seed + 4, point_type = Float32, exp_error_fn = _subspace_error))
5252

53-
append!(all_results, _benchmark_extra_retractions("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4, point_type = Float32, methods = [PolarRetraction()]))
53+
append!(all_results, _benchmark_extra_retractions("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4, point_type = Float32, methods = [PolarRetraction(), QRRetraction()], error_fn = _subspace_error))
5454

55-
append!(all_results, _benchmark_extra_retractions("Stiefel($n, $k)", Stiefel(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5, point_type = Float32, methods = [ExponentialRetraction(), PolarRetraction()]))
55+
append!(all_results, _benchmark_extra_retractions("Stiefel($n, $k)", Stiefel(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5, point_type = Float32, methods = [ExponentialRetraction(), PolarRetraction(), QRRetraction()]))
5656

5757
markdown_table = generate_markdown_summary_table(all_results)
5858
println("=== Markdown summary table ===")

benchmarks/utils.jl

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ function _print_results(;
4040
gpu_all,
4141
cpu_ms::Float64,
4242
gpu_ms::Float64,
43-
relerr,
44-
relerr_label::String,
43+
err,
44+
err_label::String,
4545
extra_lines::Vector{String} = String[],
4646
)
4747
speedup = cpu_ms / gpu_ms
@@ -56,10 +56,10 @@ function _print_results(;
5656
println("Median CPU [ms]: ", round(cpu_ms; digits = 2))
5757
println("Median GPU [ms]: ", round(gpu_ms; digits = 2))
5858
println("Speedup (CPU/GPU): ", round(speedup; digits = 2), "x")
59-
return println("Relative error $relerr_label: ", relerr)
59+
return println("Error $err_label: ", err)
6060
end
6161

62-
function _benchmark_result(; manifold_label::String, operation::String, samples::Int, cpu_ms::Float64, gpu_ms::Float64, relerr)
62+
function _benchmark_result(; manifold_label::String, operation::String, samples::Int, cpu_ms::Float64, gpu_ms::Float64, err)
6363
speedup = gpu_ms == 0.0 ? Inf : cpu_ms / gpu_ms
6464
return (
6565
manifold = manifold_label,
@@ -68,24 +68,24 @@ function _benchmark_result(; manifold_label::String, operation::String, samples:
6868
cpu_ms = cpu_ms,
6969
gpu_ms = gpu_ms,
7070
speedup = speedup,
71-
relerr = relerr,
71+
err = err,
7272
)
7373
end
7474

7575
function generate_markdown_summary_table(results)
7676
lines = String[
77-
"| Manifold | Operation | CPU median [ms] | GPU median [ms] | Speedup CPU/GPU | Relative error |",
77+
"| Manifold | Operation | CPU median [ms] | GPU median [ms] | Speedup CPU/GPU | Error |",
7878
"| --- | --- | ---: | ---: | ---: | ---: |",
7979
]
8080

8181
for r in results
8282
cpu_s = string(round(r.cpu_ms; digits = 2))
8383
gpu_s = string(round(r.gpu_ms; digits = 2))
8484
speedup_s = string(round(r.speedup; digits = 2))
85-
relerr_s = string(round(Float64(r.relerr); sigdigits = 4))
85+
err_s = string(round(Float64(r.err); sigdigits = 4))
8686
push!(
8787
lines,
88-
"| $(r.manifold) | $(r.operation) | $cpu_s | $gpu_s | $speedup_s | $relerr_s |",
88+
"| $(r.manifold) | $(r.operation) | $cpu_s | $gpu_s | $speedup_s | $err_s |",
8989
)
9090
end
9191

@@ -148,7 +148,17 @@ end
148148

149149
# --- Generic operation benchmarks ---
150150

151-
function _benchmark_exp(; MP, p_cpu, X_cpu, p_gpu, X_gpu, samples::Int, manifold_label::String, error_fn = nothing)
151+
function _benchmark_exp(;
152+
MP,
153+
p_cpu,
154+
X_cpu,
155+
p_gpu,
156+
X_gpu,
157+
samples::Int,
158+
manifold_label::String,
159+
error_fn = nothing,
160+
err_label::String = isnothing(error_fn) ? "||qcpu - qgpu||/||qcpu||" : "distance(qcpu, qgpu)",
161+
)
152162
cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu(
153163
() -> exp(MP, p_cpu, X_cpu),
154164
() -> CUDA.@sync exp(MP, p_gpu, X_gpu);
@@ -157,8 +167,7 @@ function _benchmark_exp(; MP, p_cpu, X_cpu, p_gpu, X_gpu, samples::Int, manifold
157167

158168
cpu_res = exp(MP, p_cpu, X_cpu)
159169
gpu_res = Array(CUDA.@sync exp(MP, p_gpu, X_gpu))
160-
relerr = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
161-
relerr_label = isnothing(error_fn) ? "||Ycpu - Ygpu||/||Ycpu||" : "distance(Ycpu, Ygpu)"
170+
err = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
162171

163172
_print_results(
164173
name = "exp",
@@ -168,16 +177,16 @@ function _benchmark_exp(; MP, p_cpu, X_cpu, p_gpu, X_gpu, samples::Int, manifold
168177
gpu_all = gpu_all,
169178
cpu_ms = cpu_ms,
170179
gpu_ms = gpu_ms,
171-
relerr = relerr,
172-
relerr_label = relerr_label,
180+
err = err,
181+
err_label = err_label,
173182
)
174183
return _benchmark_result(
175184
manifold_label = manifold_label,
176185
operation = "exp",
177186
samples = samples,
178187
cpu_ms = cpu_ms,
179188
gpu_ms = gpu_ms,
180-
relerr = relerr,
189+
err = err,
181190
)
182191
end
183192

@@ -190,7 +199,7 @@ function _benchmark_log!(; MP, p_cpu, q_cpu, p_gpu, q_gpu, X_cpu, X_gpu, samples
190199

191200
cpu_res = log!(MP, X_cpu, p_cpu, q_cpu)
192201
gpu_res = Array(CUDA.@sync log!(MP, X_gpu, p_gpu, q_gpu))
193-
relerr = _relative_error(cpu_res, gpu_res)
202+
err = _relative_error(cpu_res, gpu_res)
194203

195204
_print_results(
196205
name = "log!",
@@ -200,16 +209,16 @@ function _benchmark_log!(; MP, p_cpu, q_cpu, p_gpu, q_gpu, X_cpu, X_gpu, samples
200209
gpu_all = gpu_all,
201210
cpu_ms = cpu_ms,
202211
gpu_ms = gpu_ms,
203-
relerr = relerr,
204-
relerr_label = "||Xcpu - Xgpu||/||Xcpu||",
212+
err = err,
213+
err_label = "||Xcpu - Xgpu||/||Xcpu||",
205214
)
206215
return _benchmark_result(
207216
manifold_label = manifold_label,
208217
operation = "log!",
209218
samples = samples,
210219
cpu_ms = cpu_ms,
211220
gpu_ms = gpu_ms,
212-
relerr = relerr,
221+
err = err,
213222
)
214223
end
215224

@@ -222,7 +231,7 @@ function _benchmark_inner(; MP, p_cpu, X_cpu, Y_cpu, p_gpu, X_gpu, Y_gpu, sample
222231

223232
cpu_res = inner(MP, p_cpu, X_cpu, Y_cpu)
224233
gpu_res = CUDA.@sync inner(MP, p_gpu, X_gpu, Y_gpu)
225-
relerr = _relative_error(cpu_res, gpu_res)
234+
err = _relative_error(cpu_res, gpu_res)
226235

227236
_print_results(
228237
name = "inner",
@@ -232,16 +241,16 @@ function _benchmark_inner(; MP, p_cpu, X_cpu, Y_cpu, p_gpu, X_gpu, Y_gpu, sample
232241
gpu_all = gpu_all,
233242
cpu_ms = cpu_ms,
234243
gpu_ms = gpu_ms,
235-
relerr = relerr,
236-
relerr_label = "|icpu - igpu|/|icpu|",
244+
err = err,
245+
err_label = "|icpu - igpu|/|icpu|",
237246
)
238247
return _benchmark_result(
239248
manifold_label = manifold_label,
240249
operation = "inner",
241250
samples = samples,
242251
cpu_ms = cpu_ms,
243252
gpu_ms = gpu_ms,
244-
relerr = relerr,
253+
err = err,
245254
)
246255
end
247256

@@ -254,7 +263,7 @@ function _benchmark_norm(; MP, p_cpu, X_cpu, p_gpu, X_gpu, samples::Int, manifol
254263

255264
cpu_res = norm(MP, p_cpu, X_cpu)
256265
gpu_res = CUDA.@sync norm(MP, p_gpu, X_gpu)
257-
relerr = _relative_error(cpu_res, gpu_res)
266+
err = _relative_error(cpu_res, gpu_res)
258267

259268
_print_results(
260269
name = "norm",
@@ -264,16 +273,16 @@ function _benchmark_norm(; MP, p_cpu, X_cpu, p_gpu, X_gpu, samples::Int, manifol
264273
gpu_all = gpu_all,
265274
cpu_ms = cpu_ms,
266275
gpu_ms = gpu_ms,
267-
relerr = relerr,
268-
relerr_label = "|ncpu - ngpu|/|ncpu|",
276+
err = err,
277+
err_label = "|ncpu - ngpu|/|ncpu|",
269278
)
270279
return _benchmark_result(
271280
manifold_label = manifold_label,
272281
operation = "norm",
273282
samples = samples,
274283
cpu_ms = cpu_ms,
275284
gpu_ms = gpu_ms,
276-
relerr = relerr,
285+
err = err,
277286
)
278287
end
279288

@@ -286,7 +295,7 @@ function _benchmark_project!(; MP, p_cpu, Z_cpu, p_gpu, Z_gpu, X_cpu, X_gpu, sam
286295

287296
cpu_res = project!(MP, X_cpu, p_cpu, Z_cpu)
288297
gpu_res = Array(CUDA.@sync project!(MP, X_gpu, p_gpu, Z_gpu))
289-
relerr = _relative_error(cpu_res, gpu_res)
298+
err = _relative_error(cpu_res, gpu_res)
290299

291300
_print_results(
292301
name = "project!",
@@ -296,16 +305,16 @@ function _benchmark_project!(; MP, p_cpu, Z_cpu, p_gpu, Z_gpu, X_cpu, X_gpu, sam
296305
gpu_all = gpu_all,
297306
cpu_ms = cpu_ms,
298307
gpu_ms = gpu_ms,
299-
relerr = relerr,
300-
relerr_label = "||Xcpu - Xgpu||/||Xcpu||",
308+
err = err,
309+
err_label = "||Xcpu - Xgpu||/||Xcpu||",
301310
)
302311
return _benchmark_result(
303312
manifold_label = manifold_label,
304313
operation = "project!",
305314
samples = samples,
306315
cpu_ms = cpu_ms,
307316
gpu_ms = gpu_ms,
308-
relerr = relerr,
317+
err = err,
309318
)
310319
end
311320

@@ -319,6 +328,8 @@ function _benchmark_retraction(
319328
t::Float32,
320329
samples::Int,
321330
manifold_label::String,
331+
error_fn = nothing,
332+
err_label::String = isnothing(error_fn) ? "||qcpu - qgpu||/||qcpu||" : "distance(qcpu, qgpu)",
322333
)
323334
q_cpu = similar(p_cpu)
324335
q_gpu = similar(p_gpu)
@@ -333,7 +344,7 @@ function _benchmark_retraction(
333344

334345
cpu_res = exp(MP, p_cpu, X_cpu)
335346
gpu_res = Array(CUDA.@sync exp(MP, p_gpu, X_gpu))
336-
relerr = _relative_error(cpu_res, gpu_res)
347+
err = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
337348

338349
_print_results(
339350
name = method_name,
@@ -343,8 +354,8 @@ function _benchmark_retraction(
343354
gpu_all = gpu_all,
344355
cpu_ms = cpu_ms,
345356
gpu_ms = gpu_ms,
346-
relerr = relerr,
347-
relerr_label = "||Ycpu - Ygpu||/||Ycpu||",
357+
err = err,
358+
err_label = err_label,
348359
extra_lines = ["Retraction method: $method_name"],
349360
)
350361

@@ -354,7 +365,7 @@ function _benchmark_retraction(
354365
samples = samples,
355366
cpu_ms = cpu_ms,
356367
gpu_ms = gpu_ms,
357-
relerr = relerr,
368+
err = err,
358369
)
359370
end
360371

@@ -366,7 +377,7 @@ function _benchmark_retraction(
366377

367378
cpu_res = ManifoldsBase.retract_fused!(MP, q_cpu, p_cpu, X_cpu, t, method)
368379
gpu_res = Array(CUDA.@sync ManifoldsBase.retract_fused!(MP, q_gpu, p_gpu, X_gpu, t, method))
369-
relerr = _relative_error(cpu_res, gpu_res)
380+
err = isnothing(error_fn) ? _relative_error(cpu_res, gpu_res) : error_fn(MP, cpu_res, gpu_res)
370381

371382
_print_results(
372383
name = method_name,
@@ -376,8 +387,8 @@ function _benchmark_retraction(
376387
gpu_all = gpu_all,
377388
cpu_ms = cpu_ms,
378389
gpu_ms = gpu_ms,
379-
relerr = relerr,
380-
relerr_label = "||Qcpu - Qgpu||/||Qcpu||",
390+
err = err,
391+
err_label = err_label,
381392
extra_lines = ["Retraction scalar t: $t", "Retraction method: $method_name"],
382393
)
383394

@@ -387,7 +398,7 @@ function _benchmark_retraction(
387398
samples = samples,
388399
cpu_ms = cpu_ms,
389400
gpu_ms = gpu_ms,
390-
relerr = relerr,
401+
err = err,
391402
)
392403
end
393404

@@ -400,7 +411,7 @@ function _benchmark_distance(; MP, p_cpu, q_cpu, p_gpu, q_gpu, samples::Int, man
400411

401412
cpu_res = distance(MP, p_cpu, q_cpu)
402413
gpu_res = CUDA.@sync distance(MP, p_gpu, q_gpu)
403-
relerr = _relative_error(cpu_res, gpu_res)
414+
err = _relative_error(cpu_res, gpu_res)
404415

405416
_print_results(
406417
name = "distance",
@@ -410,16 +421,16 @@ function _benchmark_distance(; MP, p_cpu, q_cpu, p_gpu, q_gpu, samples::Int, man
410421
gpu_all = gpu_all,
411422
cpu_ms = cpu_ms,
412423
gpu_ms = gpu_ms,
413-
relerr = relerr,
414-
relerr_label = "|dcpu - dgpu|/|dcpu|",
424+
err = err,
425+
err_label = "|dcpu - dgpu|/|dcpu|",
415426
)
416427
return _benchmark_result(
417428
manifold_label = manifold_label,
418429
operation = "distance",
419430
samples = samples,
420431
cpu_ms = cpu_ms,
421432
gpu_ms = gpu_ms,
422-
relerr = relerr,
433+
err = err,
423434
)
424435
end
425436

ext/ManifoldsGPUCUDAExt/GeneralUnitaryMatrices.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,29 @@ function ManifoldsBase.retract_fused!(
5353
return ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
5454
end
5555

56+
function ManifoldsBase.retract_qr_fused!(
57+
::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
58+
q::CuArray{T, 3},
59+
p::CuArray{T, 3},
60+
X::CuArray{T, 3},
61+
t::Number,
62+
) where {T <: Number}
63+
q .= p
64+
CUDA.CUBLAS.gemm_strided_batched!('N', 'N', T(t), p, X, one(T), q)
65+
return _cholesky_qr_gpu!(q)
66+
end
67+
68+
function ManifoldsBase.retract_fused!(
69+
M::PowerManifold{<:Any, <:Manifolds.GeneralUnitaryMatrices, <:Tuple, ArrayPowerRepresentation},
70+
q::CuArray{T, 3},
71+
p::CuArray{T, 3},
72+
X::CuArray{T, 3},
73+
t::Number,
74+
::QRRetraction,
75+
) where {T <: Number}
76+
return ManifoldsBase.retract_qr_fused!(M, q, p, X, t)
77+
end
78+
5679
function ManifoldsBase.project!(
5780
::PowerManifold{
5881
<:Any,

0 commit comments

Comments
 (0)