77
88include (joinpath (@__DIR__ , " utils.jl" ))
99
10- function _benchmark_extra_retractions (name:: String , M; batch:: Int , scale:: Float32 , t:: Float32 , samples:: Int , seed:: Int , point_type, methods)
10+ function _benchmark_extra_retractions (name:: String , M; batch:: Int , scale:: Float32 , t:: Float32 , samples:: Int , seed:: Int , point_type, methods, error_fn = nothing )
1111 data = _setup_data (M; batch = batch, scale = scale, seed = seed, point_type = point_type, use_power_manifold = true )
1212 manifold_label = " PowerManifold($name , $batch )"
1313 results = NamedTuple[]
@@ -17,7 +17,7 @@ function _benchmark_extra_retractions(name::String, M; batch::Int, scale::Float3
1717 println ()
1818
1919 for method in methods
20- push! (results, _benchmark_retraction (method; MP = data. MB, p_cpu = data. p_cpu, X_cpu = data. X_cpu, p_gpu = data. p_gpu, X_gpu = data. X_gpu, t = t, samples = samples, manifold_label = manifold_label))
20+ push! (results, _benchmark_retraction (method; MP = data. MB, p_cpu = data. p_cpu, X_cpu = data. X_cpu, p_gpu = data. p_gpu, X_gpu = data. X_gpu, t = t, samples = samples, manifold_label = manifold_label, error_fn = error_fn ))
2121 println ()
2222 end
2323
@@ -44,15 +44,15 @@ function main()
4444
4545 append! (all_results, benchmark_manifold (" Rotations($n )" , Rotations (n); batch = batch, scale = scale, samples = samples, seed = seed + 2 , point_type = Float32))
4646
47- append! (all_results, _benchmark_extra_retractions (" Rotations($n )" , Rotations (n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2 , point_type = Float32, methods = [PolarRetraction ()]))
47+ append! (all_results, _benchmark_extra_retractions (" Rotations($n )" , Rotations (n); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 2 , point_type = Float32, methods = [PolarRetraction (), QRRetraction () ]))
4848
4949 append! (all_results, benchmark_manifold (" UnitaryMatrices($n )" , UnitaryMatrices (n); batch = batch, scale = scale, samples = samples, seed = seed + 3 , point_type = ComplexF32))
5050
5151 append! (all_results, benchmark_manifold (" Grassmann($n , $k )" , Grassmann (n, k); batch = batch, scale = scale, samples = samples, seed = seed + 4 , point_type = Float32, exp_error_fn = _subspace_error))
5252
53- append! (all_results, _benchmark_extra_retractions (" Grassmann($n , $k )" , Grassmann (n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4 , point_type = Float32, methods = [PolarRetraction ()] ))
53+ append! (all_results, _benchmark_extra_retractions (" Grassmann($n , $k )" , Grassmann (n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4 , point_type = Float32, methods = [PolarRetraction (), QRRetraction ()], error_fn = _subspace_error ))
5454
55- append! (all_results, _benchmark_extra_retractions (" Stiefel($n , $k )" , Stiefel (n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5 , point_type = Float32, methods = [ExponentialRetraction (), PolarRetraction ()]))
55+ append! (all_results, _benchmark_extra_retractions (" Stiefel($n , $k )" , Stiefel (n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5 , point_type = Float32, methods = [ExponentialRetraction (), PolarRetraction (), QRRetraction () ]))
5656
5757 markdown_table = generate_markdown_summary_table (all_results)
5858 println (" === Markdown summary table ===" )
0 commit comments