Skip to content

Commit d0f4ba1

Browse files
authored
Merge pull request #7 from zazabap/feat/grassmann-project
add GPU project! and retract_polar_fused! for Grassmann
2 parents 9f666c5 + 04a7ebb commit d0f4ba1

5 files changed

Lines changed: 243 additions & 1 deletion

File tree

benchmarks/Grassmann.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,18 @@ if abspath(PROGRAM_FILE) == @__FILE__
1111
k = _parse_arg(2, 16)
1212
batch = _parse_arg(3, 2048)
1313
samples = _parse_arg(4, 6)
14+
scale = 0.2f0
15+
t = 0.3f0
1416

1517
println("Running Grassmann benchmarks: n=$n, k=$k, batch=$batch, samples=$samples")
1618
println()
1719

18-
results = benchmark_manifold("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = 0.2f0, samples = samples, seed = 1234, point_type = Float32, exp_error_fn = _subspace_error)
20+
results = benchmark_manifold("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, samples = samples, seed = 1234, point_type = Float32, exp_error_fn = _subspace_error)
21+
22+
data = _setup_data(Grassmann(n, k); batch = batch, scale = scale, seed = 1234, point_type = Float32)
23+
manifold_label = "PowerManifold(Grassmann($n, $k), $batch)"
24+
push!(results, _benchmark_retraction(PolarRetraction(); MP = data.MB, p_cpu = data.p_cpu, X_cpu = data.X_cpu, p_gpu = data.p_gpu, X_gpu = data.X_gpu, t = t, samples = samples, manifold_label = manifold_label))
25+
println()
1926

2027
println(generate_markdown_summary_table(results))
2128
end

benchmarks/main.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ function main()
5050

5151
append!(all_results, benchmark_manifold("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, samples = samples, seed = seed + 4, point_type = Float32, exp_error_fn = _subspace_error))
5252

53+
append!(all_results, _benchmark_extra_retractions("Grassmann($n, $k)", Grassmann(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 4, point_type = Float32, methods = [PolarRetraction()]))
54+
5355
append!(all_results, _benchmark_extra_retractions("Stiefel($n, $k)", Stiefel(n, k); batch = batch, scale = scale, t = t, samples = samples, seed = seed + 5, point_type = Float32, methods = [ExponentialRetraction(), PolarRetraction()]))
5456

5557
markdown_table = generate_markdown_summary_table(all_results)

ext/ManifoldsGPUCUDAExt/Grassmann.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,45 @@ function ManifoldsBase.exp!(
4040

4141
return q
4242
end
43+
44+
# Batched GPU projection of `X` onto the tangent space at `p` for a power manifold
# of real Grassmann manifolds: evaluates `X - p * (p' * X)` using two strided-batched
# CUBLAS products, stores the result in `Y`, and returns `Y`.
# NOTE(review): presumably the third array axis indexes the power-manifold
# components -- confirm against callers.
function ManifoldsBase.project!(
    ::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
    Y::CuArray{T, 3},
    p::CuArray{T, 3},
    X::CuArray{T, 3},
) where {T <: Real}
    # p' * X
    coeffs = CUDA.CUBLAS.gemm_strided_batched('C', 'N', p, X)
    # p * (p' * X)
    in_span = CUDA.CUBLAS.gemm_strided_batched('N', 'N', p, coeffs)
    # Both products are fully computed before anything is written into Y.
    @. Y = X - in_span
    return Y
end
54+
55+
# Batched GPU point projection for a power manifold of real Grassmann manifolds:
# copies `p` into `q` element-wise and then hands `q` to `_polar_project_gpu!`,
# whose return value is returned unchanged.
# NOTE(review): `_polar_project_gpu!` is defined elsewhere; presumably it replaces
# each slice of `q` in place by its polar factor -- confirm.
function ManifoldsBase.project!(
    ::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
    q::CuArray{T, 3},
    p::CuArray{T, 3},
) where {T <: Real}
    # Work on the output buffer so that `p` itself is left untouched by this method.
    @. q = p
    projected = _polar_project_gpu!(q)
    return projected
end
63+
64+
# Batched GPU polar retraction for a power manifold of real Grassmann manifolds:
# writes the ambient step `p + t * X` into `q` with one fused broadcast and then
# passes `q` to `_polar_project_gpu!`, returning that call's result.
# NOTE(review): `_polar_project_gpu!` is defined elsewhere; presumably it maps each
# slice of `q` back onto the manifold in place -- confirm.
function ManifoldsBase.retract_polar_fused!(
    ::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
    q::CuArray{T, 3},
    p::CuArray{T, 3},
    X::CuArray{T, 3},
    t::Number,
) where {T <: Real}
    @. q = p + t * X
    retracted = _polar_project_gpu!(q)
    return retracted
end
74+
75+
# Dispatch hook: a `PolarRetraction` requested through `retract_fused!` on a
# CuArray-backed power manifold of real Grassmann manifolds is forwarded to
# `ManifoldsBase.retract_polar_fused!` with the same manifold, buffers and step `t`.
function ManifoldsBase.retract_fused!(
    M::PowerManifold{ℝ, <:Grassmann{ℝ}, <:Tuple, ArrayPowerRepresentation},
    q::CuArray{T, 3},
    p::CuArray{T, 3},
    X::CuArray{T, 3},
    t::Number,
    ::PolarRetraction,
) where {T <: Real}
    result = ManifoldsBase.retract_polar_fused!(M, q, p, X, t)
    return result
end

test/cuda/test_grassmann.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,150 @@
112112
@test distance(MP, Y_cu_h, Y) < 2.0e-12
113113
end
114114
end
115+
116+
# The batched CUDA tangent projection must agree with the per-slice CPU projection.
@testset "project! tangent Float64" begin
    Random.seed!(83)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)

    p = rand(MP)
    X = randn(size(p))

    # CPU reference: project every slice along the third axis independently.
    Y_ref = similar(X)
    for (Y_i, p_i, X_i) in zip(
            eachslice(Y_ref; dims = 3),
            eachslice(p; dims = 3),
            eachslice(X; dims = 3),
        )
        ManifoldsBase.project!(M, Y_i, p_i, X_i)
    end

    # Device run: one batched call on the power manifold.
    p_dev = CuArray(p)
    X_dev = CuArray(X)
    Y_dev = similar(X_dev)
    ManifoldsBase.project!(MP, Y_dev, p_dev, X_dev)
    Y_host = Array(Y_dev)

    @test Y_host ≈ Y_ref atol = 2.0e-14 rtol = 2.0e-14
end
140+
141+
# Single-precision variant: batched CUDA tangent projection vs. per-slice CPU result.
@testset "project! tangent Float32" begin
    Random.seed!(84)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)

    # Inputs are drawn in Float64 and then narrowed to Float32.
    p = Float32.(rand(MP))
    X = Float32.(randn(size(p)))

    # CPU reference: project every slice along the third axis independently.
    Y_ref = similar(X)
    for (Y_i, p_i, X_i) in zip(
            eachslice(Y_ref; dims = 3),
            eachslice(p; dims = 3),
            eachslice(X; dims = 3),
        )
        ManifoldsBase.project!(M, Y_i, p_i, X_i)
    end

    # Device run: one batched call on the power manifold.
    p_dev = CuArray(p)
    X_dev = CuArray(X)
    Y_dev = similar(X_dev)
    ManifoldsBase.project!(MP, Y_dev, p_dev, X_dev)
    Y_host = Array(Y_dev)

    @test Y_host ≈ Y_ref atol = 2.0f-5 rtol = 2.0f-5
end
165+
166+
# Projecting slightly perturbed points back onto the manifold: the batched CUDA
# result must be a valid point and match the per-slice CPU projection.
@testset "project! point Float64" begin
    Random.seed!(85)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)

    p = rand(MP)
    noise = randn(size(p))
    p_noisy = p .+ 0.01 .* noise

    # CPU reference: project every slice along the third axis independently.
    q_ref = similar(p)
    for (q_i, noisy_i) in zip(eachslice(q_ref; dims = 3), eachslice(p_noisy; dims = 3))
        ManifoldsBase.project!(M, q_i, noisy_i)
    end

    # Device run: one batched call on the power manifold.
    noisy_dev = CuArray(p_noisy)
    q_dev = similar(noisy_dev)
    ManifoldsBase.project!(MP, q_dev, noisy_dev)
    q_host = Array(q_dev)

    @test is_point(MP, q_host)
    @test q_host ≈ q_ref atol = 2.0e-14 rtol = 2.0e-14
end
190+
191+
# Single-precision variant of the point-projection comparison (CUDA vs. CPU).
@testset "project! point Float32" begin
    Random.seed!(86)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)

    # Inputs are drawn in Float64 and then narrowed to Float32.
    p = Float32.(rand(MP))
    noise = Float32.(randn(size(p)))
    p_noisy = p .+ Float32(0.01) .* noise

    # CPU reference: project every slice along the third axis independently.
    q_ref = similar(p)
    for (q_i, noisy_i) in zip(eachslice(q_ref; dims = 3), eachslice(p_noisy; dims = 3))
        ManifoldsBase.project!(M, q_i, noisy_i)
    end

    # Device run: one batched call on the power manifold.
    noisy_dev = CuArray(p_noisy)
    q_dev = similar(noisy_dev)
    ManifoldsBase.project!(MP, q_dev, noisy_dev)
    q_host = Array(q_dev)

    @test is_point(MP, q_host)
    @test q_host ≈ q_ref atol = 2.0f-5 rtol = 2.0f-5
end
215+
216+
# Polar retraction through `retract_fused!`: the CuArray result must be a valid
# point and agree with the result obtained on plain CPU arrays.
@testset "retract_polar_fused Float64" begin
    Random.seed!(87)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)
    t = 0.3

    p = rand(MP)
    X = rand(MP; vector_at = p)

    # CPU reference.
    q_ref = similar(p)
    ManifoldsBase.retract_fused!(MP, q_ref, p, X, t, PolarRetraction())

    # Device run with the same inputs.
    p_dev = CuArray(p)
    X_dev = CuArray(X)
    q_dev = similar(p_dev)
    ManifoldsBase.retract_fused!(MP, q_dev, p_dev, X_dev, t, PolarRetraction())
    q_host = Array(q_dev)

    @test is_point(MP, q_host)
    @test q_host ≈ q_ref atol = 2.0e-14 rtol = 2.0e-14
end
238+
239+
# Single-precision variant of the polar-retraction comparison (CUDA vs. CPU).
@testset "retract_polar_fused Float32" begin
    Random.seed!(88)

    M = Grassmann(8, 4)
    MP = PowerManifold(M, 32)
    t = Float32(0.3)

    # The base point is narrowed to Float32 before the tangent vector is drawn at it.
    p = Float32.(rand(MP))
    X = Float32.(rand(MP; vector_at = p))

    # CPU reference.
    q_ref = similar(p)
    ManifoldsBase.retract_fused!(MP, q_ref, p, X, t, PolarRetraction())

    # Device run with the same inputs.
    p_dev = CuArray(p)
    X_dev = CuArray(X)
    q_dev = similar(p_dev)
    ManifoldsBase.retract_fused!(MP, q_dev, p_dev, X_dev, t, PolarRetraction())
    q_host = Array(q_dev)

    @test is_point(MP, q_host)
    @test q_host ≈ q_ref atol = 2.0f-5 rtol = 2.0f-5
end
115261
end

test/jlarray/test_general_unitary_matrices.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,49 @@ using GPUArrays
6969
@test isapprox(q_jl_h, q_cpu; atol = 2.0e-14, rtol = 2.0e-14)
7070
end
7171
end
72+
73+
# Grassmann tangent projection on JLArray inputs (Float64): the result must be a
# tangent vector at `p` and match the plain-Array projection.
@testset "Grassmann project tangent Float64" begin
    Random.seed!(45)

    M = Grassmann(6, 3)

    # Three independent random draws from the same seeded stream.
    for _ in 1:3
        p = rand(M)
        X = randn(6, 3)

        # Reference on plain Arrays.
        Y_ref = similar(X)
        project!(M, Y_ref, p, X)

        # Same call on JLArray-backed inputs.
        p_dev = JLArray(p)
        X_dev = JLArray(X)
        Y_dev = similar(X_dev)
        project!(M, Y_dev, p_dev, X_dev)
        Y_host = Array(Y_dev)

        @test is_vector(M, p, Y_host; atol = 2.0e-14)
        @test Y_host ≈ Y_ref atol = 2.0e-14 rtol = 2.0e-14
    end
end
95+
96+
# Grassmann point projection on JLArray inputs (Float64): a slightly perturbed
# point is projected back; the result must be a valid point and match the
# plain-Array projection.
@testset "Grassmann project point Float64" begin
    Random.seed!(46)

    M = Grassmann(6, 3)

    # Three independent random draws from the same seeded stream.
    for _ in 1:3
        p = rand(M)
        noise = randn(size(p))
        p_noisy = p .+ 0.01 .* noise

        # Reference on plain Arrays.
        q_ref = similar(p)
        project!(M, q_ref, p_noisy)

        # Same call on JLArray-backed inputs.
        noisy_dev = JLArray(p_noisy)
        q_dev = similar(noisy_dev)
        project!(M, q_dev, noisy_dev)
        q_host = Array(q_dev)

        @test is_point(M, q_host)
        @test q_host ≈ q_ref atol = 2.0e-14 rtol = 2.0e-14
    end
end
72117
end

0 commit comments

Comments
 (0)