Skip to content

Commit 0094316

Browse files
committed
also make sphere less slow
1 parent ffbd5b2 commit 0094316

3 files changed

Lines changed: 191 additions & 32 deletions

File tree

docs/src/index.md

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,38 @@ ManifoldsGPU.ManifoldsGPU
88

99
=== Markdown summary table ===
1010
Device: NVIDIA GeForce RTX 5070 Ti, eltype: Float32/ComplexF32
11+
1112
| Manifold | Operation | CPU median [ms] | GPU median [ms] | Speedup CPU/GPU | Error |
1213
| --- | --- | ---: | ---: | ---: | ---: |
13-
| Euclidean(32, 16, 2048) | exp | 0.34 | 0.17 | 2.07 | 0.0 |
14-
| Euclidean(32, 16, 2048) | log! | 0.34 | 0.17 | 2.03 | 0.0 |
15-
| Euclidean(32, 16, 2048) | inner | 0.19 | 0.13 | 1.42 | 0.0 |
16-
| Euclidean(32, 16, 2048) | norm | 0.12 | 0.16 | 0.74 | 8.423e-8 |
17-
| Euclidean(32, 16, 2048) | project! | 0.23 | 0.13 | 1.76 | 0.0 |
18-
| PowerManifold(Sphere(31), 2048) | exp | 0.05 | 36.49 | 0.0 | 6.877e-8 |
19-
| PowerManifold(Sphere(31), 2048) | log! | 0.09 | 68.95 | 0.0 | 4.262e-8 |
20-
| PowerManifold(Sphere(31), 2048) | inner | 0.02 | 0.12 | 0.12 | 6.837e-7 |
21-
| PowerManifold(Sphere(31), 2048) | norm | 0.02 | 0.13 | 0.15 | 0.0 |
22-
| PowerManifold(Sphere(31), 2048) | project! | 0.03 | 34.91 | 0.0 | 2.813e-8 |
23-
| PowerManifold(Rotations(32), 2048) | exp | 36.83 | 2.35 | 15.68 | 2.594e-6 |
24-
| PowerManifold(Rotations(32), 2048) | log! | 561.36 | 67.61 | 8.3 | 9.157e-5 |
25-
| PowerManifold(Rotations(32), 2048) | inner | 0.41 | 0.14 | 2.85 | 4.508e-6 |
26-
| PowerManifold(Rotations(32), 2048) | norm | 1.32 | 0.14 | 9.64 | 1.109e-6 |
27-
| PowerManifold(Rotations(32), 2048) | project! | 20.33 | 0.22 | 93.02 | 3.644e-7 |
28-
| PowerManifold(Rotations(32), 2048) | retract_fused!(PolarRetraction) | 116.77 | 4.82 | 24.24 | 2.555e-6 |
29-
| PowerManifold(Rotations(32), 2048) | retract_fused!(QRRetraction) | 90.92 | 0.85 | 106.99 | 3.204e-7 |
30-
| PowerManifold(UnitaryMatrices(32), 2048) | exp | 86.44 | 7.65 | 11.3 | 1.957e-6 |
31-
| PowerManifold(UnitaryMatrices(32), 2048) | log! | 728.19 | 70.03 | 10.4 | 0.0001844 |
32-
| PowerManifold(UnitaryMatrices(32), 2048) | inner | 0.78 | 52.19 | 0.02 | 5.979e-5 |
33-
| PowerManifold(UnitaryMatrices(32), 2048) | norm | 1.76 | 41.09 | 0.04 | 1.516e-6 |
34-
| PowerManifold(UnitaryMatrices(32), 2048) | project! | 31.94 | 0.36 | 89.84 | 5.512e-7 |
35-
| PowerManifold(Grassmann(32, 16), 2048) | exp | 69.47 | 5.23 | 13.27 | 7.023e-5 |
36-
| PowerManifold(Grassmann(32, 16), 2048) | log! | 57.7 | 3.42 | 16.85 | 2.332e-5 |
37-
| PowerManifold(Grassmann(32, 16), 2048) | inner | 0.2 | 0.12 | 1.58 | 4.957e-7 |
38-
| PowerManifold(Grassmann(32, 16), 2048) | norm | 0.79 | 0.13 | 6.18 | 2.772e-7 |
39-
| PowerManifold(Grassmann(32, 16), 2048) | project! | 1.03 | 0.19 | 5.31 | 1.303e-7 |
40-
| PowerManifold(Grassmann(32, 16), 2048) | retract_fused!(PolarRetraction) | 41.38 | 2.75 | 15.06 | 0.0001873 |
41-
| PowerManifold(Grassmann(32, 16), 2048) | retract_fused!(QRRetraction) | 17.85 | 0.7 | 25.33 | 4.623e-5 |
42-
| PowerManifold(Stiefel(32, 16), 2048) | exp(ExponentialRetraction) | 71.95 | 3.65 | 19.71 | 1.164e-6 |
43-
| PowerManifold(Stiefel(32, 16), 2048) | retract_fused!(PolarRetraction) | 43.03 | 3.06 | 14.07 | 1.37e-6 |
44-
| PowerManifold(Stiefel(32, 16), 2048) | retract_fused!(QRRetraction) | 18.38 | 0.71 | 25.88 | 1.885e-7 |
14+
| Euclidean(32, 16, 2048) | exp | 0.35 | 0.17 | 2.06 | 0.0 |
15+
| Euclidean(32, 16, 2048) | log! | 0.35 | 0.17 | 2.05 | 0.0 |
16+
| Euclidean(32, 16, 2048) | inner | 0.19 | 0.14 | 1.34 | 9.357e-8 |
17+
| Euclidean(32, 16, 2048) | norm | 0.13 | 0.16 | 0.83 | 8.423e-8 |
18+
| Euclidean(32, 16, 2048) | project! | 0.23 | 0.12 | 1.89 | 0.0 |
19+
| PowerManifold(Sphere(31), 2048) | exp | 0.05 | 0.14 | 0.35 | 7.092e-8 |
20+
| PowerManifold(Sphere(31), 2048) | log! | 0.08 | 0.37 | 0.23 | 5.125e-8 |
21+
| PowerManifold(Sphere(31), 2048) | inner | 0.02 | 0.13 | 0.14 | 5.86e-7 |
22+
| PowerManifold(Sphere(31), 2048) | norm | 0.02 | 0.13 | 0.15 | 1.064e-7 |
23+
| PowerManifold(Sphere(31), 2048) | project! | 0.03 | 0.15 | 0.18 | 2.819e-8 |
24+
| PowerManifold(Rotations(32), 2048) | exp | 36.19 | 2.35 | 15.38 | 2.594e-6 |
25+
| PowerManifold(Rotations(32), 2048) | log! | 565.72 | 74.44 | 7.6 | 9.157e-5 |
26+
| PowerManifold(Rotations(32), 2048) | inner | 0.41 | 0.25 | 1.65 | 4.708e-6 |
27+
| PowerManifold(Rotations(32), 2048) | norm | 1.36 | 0.14 | 9.53 | 1.109e-6 |
28+
| PowerManifold(Rotations(32), 2048) | project! | 20.46 | 0.22 | 91.22 | 3.644e-7 |
29+
| PowerManifold(Rotations(32), 2048) | retract_fused!(PolarRetraction) | 115.51 | 4.89 | 23.62 | 2.555e-6 |
30+
| PowerManifold(Rotations(32), 2048) | retract_fused!(QRRetraction) | 90.11 | 0.83 | 108.95 | 3.204e-7 |
31+
| PowerManifold(UnitaryMatrices(32), 2048) | exp | 85.89 | 8.07 | 10.64 | 1.957e-6 |
32+
| PowerManifold(UnitaryMatrices(32), 2048) | log! | 729.57 | 69.89 | 10.44 | 0.0001844 |
33+
| PowerManifold(UnitaryMatrices(32), 2048) | inner | 0.83 | 56.26 | 0.01 | 5.979e-5 |
34+
| PowerManifold(UnitaryMatrices(32), 2048) | norm | 1.74 | 44.17 | 0.04 | 1.516e-6 |
35+
| PowerManifold(UnitaryMatrices(32), 2048) | project! | 31.42 | 0.35 | 90.36 | 5.512e-7 |
36+
| PowerManifold(Grassmann(32, 16), 2048) | exp | 69.72 | 5.27 | 13.22 | 7.023e-5 |
37+
| PowerManifold(Grassmann(32, 16), 2048) | log! | 57.99 | 3.36 | 17.27 | 2.332e-5 |
38+
| PowerManifold(Grassmann(32, 16), 2048) | inner | 0.2 | 0.13 | 1.54 | 8.056e-7 |
39+
| PowerManifold(Grassmann(32, 16), 2048) | norm | 0.81 | 0.14 | 5.95 | 3.696e-7 |
40+
| PowerManifold(Grassmann(32, 16), 2048) | project! | 1.0 | 0.23 | 4.45 | 1.303e-7 |
41+
| PowerManifold(Grassmann(32, 16), 2048) | retract_fused!(PolarRetraction) | 40.67 | 2.99 | 13.59 | 0.0001873 |
42+
| PowerManifold(Grassmann(32, 16), 2048) | retract_fused!(QRRetraction) | 17.84 | 0.72 | 24.63 | 4.623e-5 |
43+
| PowerManifold(Stiefel(32, 16), 2048) | exp(ExponentialRetraction) | 70.99 | 3.56 | 19.91 | 1.164e-6 |
44+
| PowerManifold(Stiefel(32, 16), 2048) | retract_fused!(PolarRetraction) | 43.66 | 2.87 | 15.21 | 1.37e-6 |
45+
| PowerManifold(Stiefel(32, 16), 2048) | retract_fused!(QRRetraction) | 18.28 | 0.71 | 25.89 | 1.885e-7 |

ext/ManifoldsGPUCUDAExt/Sphere.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,61 @@ function ManifoldsBase.norm(
1414
) where {T <: Real}
1515
return sqrt(dot(X, X))
1616
end
17+
18+
function ManifoldsBase.exp!(
19+
::PowerManifold{ℝ, <:Sphere{ℝ}, <:Tuple, ArrayPowerRepresentation},
20+
q::CuArray{T},
21+
p::CuArray{T},
22+
X::CuArray{T},
23+
) where {T <: Real}
24+
θ = sqrt.(sum(abs2, X; dims = 1))
25+
q .= cos.(θ) .* p .+ Manifolds.usinc.(θ) .* X
26+
return q
27+
end
28+
29+
function ManifoldsBase.log!(
30+
M::PowerManifold{ℝ, <:Sphere{ℝ}, <:Tuple, ArrayPowerRepresentation},
31+
X::CuArray{T},
32+
p::CuArray{T},
33+
q::CuArray{T},
34+
) where {T <: Real}
35+
cosθ = clamp.(sum(p .* q; dims = 1), -one(T), one(T))
36+
θ = acos.(cosθ)
37+
38+
X_regular = (q .- cosθ .* p) ./ Manifolds.usinc.(θ)
39+
40+
antipodal = abs.(cosθ .+ one(T)) .<= sqrt(eps(T))
41+
basis = CUDA.zeros(T, size(p))
42+
basis[1, :] .= one(T)
43+
if size(p, 1) > 1
44+
p1_is_one = abs.(p[1:1, :] .- one(T)) .<= sqrt(eps(T))
45+
basis[1, :] .-= T.(p1_is_one[1, :])
46+
basis[2, :] .= T.(p1_is_one[1, :])
47+
end
48+
49+
X_antipodal = basis .- p .* sum(p .* basis; dims = 1)
50+
X_antipodal .*= T(π) ./ sqrt.(sum(abs2, X_antipodal; dims = 1))
51+
52+
X .= ifelse.(antipodal, X_antipodal, X_regular)
53+
return project!(M, X, p, X)
54+
end
55+
56+
function ManifoldsBase.project!(
57+
::PowerManifold{ℝ, <:Sphere{ℝ}, <:Tuple, ArrayPowerRepresentation},
58+
Y::CuArray{T},
59+
p::CuArray{T},
60+
X::CuArray{T},
61+
) where {T <: Real}
62+
Y .= X .- p .* sum(p .* X; dims = 1)
63+
return Y
64+
end
65+
66+
function ManifoldsBase.project!(
67+
::PowerManifold{ℝ, <:Sphere{ℝ}, <:Tuple, ArrayPowerRepresentation},
68+
q::CuArray{T},
69+
p::CuArray{T},
70+
) where {T <: Real}
71+
norms_p = sqrt.(sum(abs2, p; dims = 1))
72+
q .= p ./ norms_p
73+
return q
74+
end

test/cuda/test_sphere.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Manifolds, ManifoldsGPU, Test, Random, CUDA
2+
13
@testset "Sphere CUDA" begin
24
@testset "inner and norm" begin
35
Random.seed!(70)
@@ -46,4 +48,102 @@
4648
@test isapprox(i_gpu, i_cpu; atol = 1.0f-4, rtol = 1.0f-4)
4749
@test isapprox(n_gpu, n_cpu; atol = 1.0f-4, rtol = 1.0f-4)
4850
end
51+
52+
@testset "project!" begin
53+
for T in [Float32, Float64]
54+
Random.seed!(72)
55+
56+
M = Sphere(7)
57+
MP = PowerManifold(M, 32)
58+
59+
p = T.(rand(MP))
60+
q_cpu = similar(p)
61+
project!(MP, q_cpu, p)
62+
63+
p_cu = CuArray(p)
64+
q_cu = similar(p_cu)
65+
project!(MP, q_cu, p_cu)
66+
67+
if T === Float32
68+
@test isapprox(Array(q_cu), q_cpu; atol = 1.0f-5, rtol = 1.0f-5)
69+
else
70+
@test isapprox(Array(q_cu), q_cpu; atol = 1.0e-12, rtol = 1.0e-12)
71+
end
72+
end
73+
end
74+
75+
@testset "project! tangent" begin
76+
for (seed, T, atol, rtol) in (
77+
(721, Float32, 1.0f-5, 1.0f-5),
78+
(722, Float64, 1.0e-12, 1.0e-12),
79+
)
80+
Random.seed!(seed)
81+
82+
M = Sphere(7)
83+
MP = PowerManifold(M, 32)
84+
85+
p = T.(rand(MP))
86+
X = T.(randn(size(p)))
87+
88+
Y_cpu = similar(p)
89+
project!(MP, Y_cpu, p, X)
90+
91+
p_cu = CuArray(p)
92+
X_cu = CuArray(X)
93+
Y_cu = similar(p_cu)
94+
project!(MP, Y_cu, p_cu, X_cu)
95+
96+
@test isapprox(Array(Y_cu), Y_cpu; atol = atol, rtol = rtol)
97+
end
98+
end
99+
100+
@testset "exp!" begin
101+
for (seed, T, atol, rtol) in (
102+
(73, Float32, 1.0f-4, 1.0f-4),
103+
(74, Float64, 1.0e-10, 1.0e-10),
104+
)
105+
Random.seed!(seed)
106+
107+
M = Sphere(7)
108+
MP = PowerManifold(M, 32)
109+
110+
p = T.(rand(MP))
111+
X = T.(rand(MP; vector_at = p))
112+
113+
q_cpu = similar(p)
114+
exp!(MP, q_cpu, p, X)
115+
116+
p_cu = CuArray(p)
117+
X_cu = CuArray(X)
118+
q_cu = similar(p_cu)
119+
exp!(MP, q_cu, p_cu, X_cu)
120+
121+
@test isapprox(Array(q_cu), q_cpu; atol = atol, rtol = rtol)
122+
end
123+
end
124+
125+
@testset "log!" begin
126+
for (seed, T, atol, rtol) in (
127+
(75, Float32, 5.0f-4, 5.0f-4),
128+
(76, Float64, 1.0e-9, 1.0e-9),
129+
)
130+
Random.seed!(seed)
131+
132+
M = Sphere(7)
133+
MP = PowerManifold(M, 32)
134+
135+
p = T.(rand(MP))
136+
q = T.(rand(MP))
137+
138+
X_cpu = similar(p)
139+
log!(MP, X_cpu, p, q)
140+
141+
p_cu = CuArray(p)
142+
q_cu = CuArray(q)
143+
X_cu = similar(p_cu)
144+
log!(MP, X_cu, p_cu, q_cu)
145+
146+
@test isapprox(Array(X_cu), X_cpu; atol = atol, rtol = rtol)
147+
end
148+
end
49149
end

0 commit comments

Comments
 (0)