Skip to content

Commit 005ec5c

Browse files
committed
add GPU project! (point + tangent) for Stiefel
1 parent dfabf06 commit 005ec5c

2 files changed

Lines changed: 120 additions & 0 deletions

File tree

ext/ManifoldsGPUCUDAExt/Stiefel.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,28 @@ function ManifoldsBase.norm(
4848
return sqrt(dot(X, X))
4949
end
5050

51+
function ManifoldsBase.project!(
52+
::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
53+
q::CuArray{T, 3},
54+
p::CuArray{T, 3},
55+
) where {T <: Real}
56+
q .= p
57+
return _polar_project_gpu!(q)
58+
end
59+
60+
function ManifoldsBase.project!(
61+
::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
62+
Y::CuArray{T, 3},
63+
p::CuArray{T, 3},
64+
X::CuArray{T, 3},
65+
) where {T <: Real}
66+
A = CUDA.CUBLAS.gemm_strided_batched('C', 'N', p, X) # p'X, k×k×batch
67+
sym = A .+ permutedims(A, (2, 1, 3)) # A + A'
68+
Y .= X
69+
CUDA.CUBLAS.gemm_strided_batched!('N', 'N', T(-0.5), p, sym, one(T), Y)
70+
return Y
71+
end
72+
5173
function ManifoldsBase.retract_polar_fused!(
5274
::PowerManifold{ℝ, <:Stiefel{ℝ}, <:Tuple, ArrayPowerRepresentation},
5375
q::CuArray{T, 3},

test/cuda/test_stiefel.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,54 @@
150150
@test isapprox(n_gpu, n_cpu; atol = 1.0e-10, rtol = 1.0e-10)
151151
end
152152

153+
@testset "project! point Float64" begin
154+
Random.seed!(76)
155+
156+
M = Stiefel(8, 4)
157+
MP = PowerManifold(M, 32)
158+
159+
p = randn(size(rand(MP))...)
160+
161+
q_cpu = similar(p)
162+
for i in 1:size(p, 3)
163+
ManifoldsBase.project!(
164+
M, view(q_cpu, :, :, i), view(p, :, :, i)
165+
)
166+
end
167+
168+
p_cu = CuArray(p)
169+
q_cu = similar(p_cu)
170+
ManifoldsBase.project!(MP, q_cu, p_cu)
171+
q_cu_h = Array(q_cu)
172+
173+
@test is_point(MP, q_cu_h)
174+
@test isapprox(q_cu_h, q_cpu; atol = 2.0e-14, rtol = 2.0e-14)
175+
end
176+
177+
@testset "project! point Float32" begin
178+
Random.seed!(77)
179+
180+
M = Stiefel(8, 4)
181+
MP = PowerManifold(M, 32)
182+
183+
p = Float32.(randn(size(rand(MP))...))
184+
185+
q_cpu = similar(p)
186+
for i in 1:size(p, 3)
187+
ManifoldsBase.project!(
188+
M, view(q_cpu, :, :, i), view(p, :, :, i)
189+
)
190+
end
191+
192+
p_cu = CuArray(p)
193+
q_cu = similar(p_cu)
194+
ManifoldsBase.project!(MP, q_cu, p_cu)
195+
q_cu_h = Array(q_cu)
196+
197+
@test is_point(MP, q_cu_h)
198+
@test isapprox(q_cu_h, q_cpu; atol = 2.0f-5, rtol = 2.0f-5)
199+
end
200+
153201
@testset "inner and norm Float32" begin
154202
Random.seed!(75)
155203

@@ -173,4 +221,54 @@
173221
@test isapprox(i_gpu, i_cpu; atol = 1.0f-4, rtol = 1.0f-4)
174222
@test isapprox(n_gpu, n_cpu; atol = 1.0f-4, rtol = 1.0f-4)
175223
end
224+
225+
@testset "project! tangent Float64" begin
226+
Random.seed!(78)
227+
228+
M = Stiefel(8, 4)
229+
MP = PowerManifold(M, 32)
230+
231+
p = rand(MP)
232+
X = randn(size(p)...)
233+
234+
Y_cpu = similar(X)
235+
for i in 1:size(p, 3)
236+
ManifoldsBase.project!(
237+
M, view(Y_cpu, :, :, i), view(p, :, :, i), view(X, :, :, i)
238+
)
239+
end
240+
241+
p_cu = CuArray(p)
242+
X_cu = CuArray(X)
243+
Y_cu = similar(X_cu)
244+
ManifoldsBase.project!(MP, Y_cu, p_cu, X_cu)
245+
Y_cu_h = Array(Y_cu)
246+
247+
@test isapprox(Y_cu_h, Y_cpu; atol = 2.0e-14, rtol = 2.0e-14)
248+
end
249+
250+
@testset "project! tangent Float32" begin
251+
Random.seed!(79)
252+
253+
M = Stiefel(8, 4)
254+
MP = PowerManifold(M, 32)
255+
256+
p = Float32.(rand(MP))
257+
X = Float32.(randn(size(p)...))
258+
259+
Y_cpu = similar(X)
260+
for i in 1:size(p, 3)
261+
ManifoldsBase.project!(
262+
M, view(Y_cpu, :, :, i), view(p, :, :, i), view(X, :, :, i)
263+
)
264+
end
265+
266+
p_cu = CuArray(p)
267+
X_cu = CuArray(X)
268+
Y_cu = similar(X_cu)
269+
ManifoldsBase.project!(MP, Y_cu, p_cu, X_cu)
270+
Y_cu_h = Array(Y_cu)
271+
272+
@test isapprox(Y_cu_h, Y_cpu; atol = 2.0f-5, rtol = 2.0f-5)
273+
end
176274
end

0 commit comments

Comments
 (0)