|
1 | 1 | using Manopt, Manifolds, ManifoldsBase, Test, Random, LinearAlgebra |
2 | 2 | using LinearAlgebra: Diagonal, dot, eigvals, eigvecs |
| 3 | +using ManifoldDiff: grad_distance |
3 | 4 |
|
4 | 5 | @testset "Conjugate Gradient Descent" begin |
5 | 6 | @testset "Conjugate Gradient coefficient rules" begin |
@@ -31,7 +32,7 @@ using LinearAlgebra: Diagonal, dot, eigvals, eigvecs |
31 | 32 | initial_gradient = zero_vector(M, x0), |
32 | 33 | ) |
33 | 34 | @test s1.coefficient(dmp, s1, 1) == 0 |
34 | | - @test default_stepsize(M, typeof(s1)) isa Manopt.ArmijoLinesearchStepsize |
| 35 | + @test default_stepsize(M, typeof(s1)) isa Manopt.ManifoldDefaultsFactory{Manopt.ArmijoLinesearchStepsize} |
35 | 36 | @test Manopt.get_message(s1) == "" |
36 | 37 |
|
37 | 38 | dU = Manopt.ConjugateDescentCoefficient() |
@@ -394,4 +395,13 @@ using LinearAlgebra: Diagonal, dot, eigvals, eigvecs |
394 | 395 | ) |
395 | 396 | @test q2 ≈ [1, 0, 0] rtol = 1.0e-7 |
396 | 397 | end |
| 398 | + |
| 399 | + @testset "Custom point types" begin |
| 400 | + M = Hyperbolic(2) |
| 401 | + data = PoincareBallPoint.([[0.1, 0.2], [0.3, 0.25], [0.35, 0.4]]) |
| 402 | + n = length(data) |
| 403 | + f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2) |
| 404 | + grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p))) |
| 405 | + @test conjugate_gradient_descent(M, f, grad_f, data[1]) isa PoincareBallPoint |
| 406 | + end |
397 | 407 | end |
0 commit comments