|
150 | 150 | @test isapprox(n_gpu, n_cpu; atol = 1.0e-10, rtol = 1.0e-10) |
151 | 151 | end |
152 | 152 |
|
| 153 | + @testset "project! point Float64" begin |
| 154 | + Random.seed!(76) |
| 155 | + |
| 156 | + M = Stiefel(8, 4) |
| 157 | + MP = PowerManifold(M, 32) |
| 158 | + |
| 159 | + p = randn(size(rand(MP))...) |
| 160 | + |
| 161 | + q_cpu = similar(p) |
| 162 | + for i in 1:size(p, 3) |
| 163 | + ManifoldsBase.project!( |
| 164 | + M, view(q_cpu, :, :, i), view(p, :, :, i) |
| 165 | + ) |
| 166 | + end |
| 167 | + |
| 168 | + p_cu = CuArray(p) |
| 169 | + q_cu = similar(p_cu) |
| 170 | + ManifoldsBase.project!(MP, q_cu, p_cu) |
| 171 | + q_cu_h = Array(q_cu) |
| 172 | + |
| 173 | + @test is_point(MP, q_cu_h) |
| 174 | + @test isapprox(q_cu_h, q_cpu; atol = 2.0e-14, rtol = 2.0e-14) |
| 175 | + end |
| 176 | + |
| 177 | + @testset "project! point Float32" begin |
| 178 | + Random.seed!(77) |
| 179 | + |
| 180 | + M = Stiefel(8, 4) |
| 181 | + MP = PowerManifold(M, 32) |
| 182 | + |
| 183 | + p = Float32.(randn(size(rand(MP))...)) |
| 184 | + |
| 185 | + q_cpu = similar(p) |
| 186 | + for i in 1:size(p, 3) |
| 187 | + ManifoldsBase.project!( |
| 188 | + M, view(q_cpu, :, :, i), view(p, :, :, i) |
| 189 | + ) |
| 190 | + end |
| 191 | + |
| 192 | + p_cu = CuArray(p) |
| 193 | + q_cu = similar(p_cu) |
| 194 | + ManifoldsBase.project!(MP, q_cu, p_cu) |
| 195 | + q_cu_h = Array(q_cu) |
| 196 | + |
| 197 | + @test is_point(MP, q_cu_h) |
| 198 | + @test isapprox(q_cu_h, q_cpu; atol = 2.0f-5, rtol = 2.0f-5) |
| 199 | + end |
| 200 | + |
153 | 201 | @testset "inner and norm Float32" begin |
154 | 202 | Random.seed!(75) |
155 | 203 |
|
|
173 | 221 | @test isapprox(i_gpu, i_cpu; atol = 1.0f-4, rtol = 1.0f-4) |
174 | 222 | @test isapprox(n_gpu, n_cpu; atol = 1.0f-4, rtol = 1.0f-4) |
175 | 223 | end |
| 224 | + |
| 225 | + @testset "project! tangent Float64" begin |
| 226 | + Random.seed!(78) |
| 227 | + |
| 228 | + M = Stiefel(8, 4) |
| 229 | + MP = PowerManifold(M, 32) |
| 230 | + |
| 231 | + p = rand(MP) |
| 232 | + X = randn(size(p)...) |
| 233 | + |
| 234 | + Y_cpu = similar(X) |
| 235 | + for i in 1:size(p, 3) |
| 236 | + ManifoldsBase.project!( |
| 237 | + M, view(Y_cpu, :, :, i), view(p, :, :, i), view(X, :, :, i) |
| 238 | + ) |
| 239 | + end |
| 240 | + |
| 241 | + p_cu = CuArray(p) |
| 242 | + X_cu = CuArray(X) |
| 243 | + Y_cu = similar(X_cu) |
| 244 | + ManifoldsBase.project!(MP, Y_cu, p_cu, X_cu) |
| 245 | + Y_cu_h = Array(Y_cu) |
| 246 | + |
| 247 | + @test isapprox(Y_cu_h, Y_cpu; atol = 2.0e-14, rtol = 2.0e-14) |
| 248 | + end |
| 249 | + |
| 250 | + @testset "project! tangent Float32" begin |
| 251 | + Random.seed!(79) |
| 252 | + |
| 253 | + M = Stiefel(8, 4) |
| 254 | + MP = PowerManifold(M, 32) |
| 255 | + |
| 256 | + p = Float32.(rand(MP)) |
| 257 | + X = Float32.(randn(size(p)...)) |
| 258 | + |
| 259 | + Y_cpu = similar(X) |
| 260 | + for i in 1:size(p, 3) |
| 261 | + ManifoldsBase.project!( |
| 262 | + M, view(Y_cpu, :, :, i), view(p, :, :, i), view(X, :, :, i) |
| 263 | + ) |
| 264 | + end |
| 265 | + |
| 266 | + p_cu = CuArray(p) |
| 267 | + X_cu = CuArray(X) |
| 268 | + Y_cu = similar(X_cu) |
| 269 | + ManifoldsBase.project!(MP, Y_cu, p_cu, X_cu) |
| 270 | + Y_cu_h = Array(Y_cu) |
| 271 | + |
| 272 | + @test isapprox(Y_cu_h, Y_cpu; atol = 2.0f-5, rtol = 2.0f-5) |
| 273 | + end |
176 | 274 | end |
0 commit comments