@@ -3,6 +3,7 @@ using Statistics
33
44using ManifoldsGPU
55using Manifolds
6+ using ManifoldsBase
67using CUDA
78
89function _time_median (f; samples:: Int = 6 )
@@ -16,7 +17,7 @@ function _time_median(f; samples::Int = 6)
1617 return median (timings), timings
1718end
1819
19- function benchmark_stiefel_exp (; n:: Int = 32 , k:: Int = 16 , batch:: Int = 2048 , scale:: Float32 = 0.2f0 , samples :: Int = 6 , seed:: Int = 1234 )
20+ function _setup_stiefel_data (; n:: Int , k:: Int , batch:: Int , scale:: Float32 , seed:: Int )
2021 Random. seed! (seed)
2122
2223 M = Stiefel (n, k)
@@ -28,33 +29,109 @@ function benchmark_stiefel_exp(; n::Int = 32, k::Int = 16, batch::Int = 2048, sc
2829 p_gpu = CuArray (p_cpu)
2930 X_gpu = CuArray (X_cpu)
3031
31- exp ( MP, p_cpu, X_cpu)
32- CUDA . @sync exp (MP, p_gpu, X_gpu)
32+ return (; MP, p_cpu, X_cpu, p_gpu, X_gpu )
33+ end
3334
34- cpu_ms, cpu_all = _time_median ( ; samples = samples) do
35- exp (MP, p_cpu, X_cpu )
36- end
35+ function _benchmark_cpu_gpu (cpu_f, gpu_f ; samples:: Int )
36+ cpu_f ( )
37+ gpu_f ()
3738
38- gpu_ms, gpu_all = _time_median (; samples = samples) do
39- CUDA. @sync exp (MP, p_gpu, X_gpu)
40- end
39+ cpu_ms, cpu_all = _time_median (cpu_f; samples = samples)
40+ gpu_ms, gpu_all = _time_median (gpu_f; samples = samples)
4141
42+ return cpu_ms, cpu_all, gpu_ms, gpu_all
43+ end
44+
45+ function _print_results (; name:: String , n:: Int , k:: Int , batch:: Int , samples:: Int , cpu_all, gpu_all, cpu_ms:: Float64 , gpu_ms:: Float64 , relerr, relerr_label:: String , extra_lines:: Vector{String} = String[])
4246 speedup = cpu_ms / gpu_ms
43- relerr = begin
44- Y_cpu = exp (MP, p_cpu, X_cpu)
45- Y_gpu = Array (CUDA. @sync exp (MP, p_gpu, X_gpu))
46- norm (Y_cpu .- Y_gpu) / max (norm (Y_cpu), eps (Float32))
47- end
4847
49- println (" === ManifoldsGPU benchmark: exp on PowerManifold(Stiefel($n , $k ), $batch ) ===" )
48+ println (" === ManifoldsGPU benchmark: $name on PowerManifold(Stiefel($n , $k ), $batch ) ===" )
5049 println (" Element type: Float32" )
50+ for line in extra_lines
51+ println (line)
52+ end
5153 println (" Samples: $samples " )
5254 println (" CPU times [ms]: " , round .(cpu_all; digits = 2 ))
5355 println (" GPU times [ms]: " , round .(gpu_all; digits = 2 ))
5456 println (" Median CPU [ms]: " , round (cpu_ms; digits = 2 ))
5557 println (" Median GPU [ms]: " , round (gpu_ms; digits = 2 ))
5658 println (" Speedup (CPU/GPU): " , round (speedup; digits = 2 ), " x" )
57- return println (" Relative error ||Ycpu - Ygpu||/||Ycpu||: " , relerr)
59+ return println (" Relative error $relerr_label : " , relerr)
60+ end
61+
62+ function benchmark_stiefel_exp (; n:: Int = 32 , k:: Int = 16 , batch:: Int = 2048 , scale:: Float32 = 0.2f0 , samples:: Int = 6 , seed:: Int = 1234 )
63+ data = _setup_stiefel_data (; n = n, k = k, batch = batch, scale = scale, seed = seed)
64+ MP = data. MP
65+ p_cpu = data. p_cpu
66+ X_cpu = data. X_cpu
67+ p_gpu = data. p_gpu
68+ X_gpu = data. X_gpu
69+
70+ cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu (
71+ () -> exp (MP, p_cpu, X_cpu),
72+ () -> CUDA. @sync exp (MP, p_gpu, X_gpu);
73+ samples = samples,
74+ )
75+
76+ relerr = begin
77+ Y_cpu = exp (MP, p_cpu, X_cpu)
78+ Y_gpu = Array (CUDA. @sync exp (MP, p_gpu, X_gpu))
79+ norm (Y_cpu .- Y_gpu) / max (norm (Y_cpu), eps (Float32))
80+ end
81+
82+ return _print_results (
83+ name = " exp" ,
84+ n = n,
85+ k = k,
86+ batch = batch,
87+ samples = samples,
88+ cpu_all = cpu_all,
89+ gpu_all = gpu_all,
90+ cpu_ms = cpu_ms,
91+ gpu_ms = gpu_ms,
92+ relerr = relerr,
93+ relerr_label = " ||Ycpu - Ygpu||/||Ycpu||" ,
94+ )
95+ end
96+
97+ function benchmark_stiefel_retract_qr_fused (; n:: Int = 32 , k:: Int = 16 , batch:: Int = 2048 , scale:: Float32 = 0.2f0 , t:: Float32 = 0.3f0 , samples:: Int = 6 , seed:: Int = 1234 )
98+ data = _setup_stiefel_data (; n = n, k = k, batch = batch, scale = scale, seed = seed)
99+ MP = data. MP
100+ p_cpu = data. p_cpu
101+ X_cpu = data. X_cpu
102+ p_gpu = data. p_gpu
103+ X_gpu = data. X_gpu
104+
105+ q_cpu = similar (p_cpu)
106+ q_gpu = similar (p_gpu)
107+
108+ cpu_ms, cpu_all, gpu_ms, gpu_all = _benchmark_cpu_gpu (
109+ () -> ManifoldsBase. retract_fused! (MP, q_cpu, p_cpu, X_cpu, t, QRRetraction ()),
110+ () -> CUDA. @sync ManifoldsBase. retract_fused! (MP, q_gpu, p_gpu, X_gpu, t, QRRetraction ());
111+ samples = samples,
112+ )
113+
114+ relerr = begin
115+ ManifoldsBase. retract_fused! (MP, q_cpu, p_cpu, X_cpu, t, QRRetraction ())
116+ CUDA. @sync ManifoldsBase. retract_fused! (MP, q_gpu, p_gpu, X_gpu, t, QRRetraction ())
117+ q_gpu_h = Array (q_gpu)
118+ norm (q_cpu .- q_gpu_h) / max (norm (q_cpu), eps (Float32))
119+ end
120+
121+ return _print_results (
122+ name = " retract_qr_fused" ,
123+ n = n,
124+ k = k,
125+ batch = batch,
126+ samples = samples,
127+ cpu_all = cpu_all,
128+ gpu_all = gpu_all,
129+ cpu_ms = cpu_ms,
130+ gpu_ms = gpu_ms,
131+ relerr = relerr,
132+ relerr_label = " ||Qcpu - Qgpu||/||Qcpu||" ,
133+ extra_lines = [" Retraction scalar t: $t " ],
134+ )
58135end
59136
60137function _parse_arg (i:: Int , default)
@@ -68,8 +145,10 @@ function main()
68145 samples = _parse_arg (4 , 6 )
69146
70147 println (" Running with n=$n , k=$k , batch=$batch , samples=$samples " )
71-
72- return benchmark_stiefel_exp (; n = n, k = k, batch = batch, samples = samples)
148+ println ()
149+ benchmark_stiefel_exp (; n = n, k = k, batch = batch, samples = samples)
150+ println ()
151+ return benchmark_stiefel_retract_qr_fused (; n = n, k = k, batch = batch, samples = samples)
73152end
74153
75154main ()
0 commit comments