Skip to content

Commit 1525e4a

Browse files
committed
Improve naming, test coverage and REPL display.
1 parent 8514a44 commit 1525e4a

9 files changed

Lines changed: 161 additions & 39 deletions

File tree

Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212

1313
* a `ScaledManifoldObjective` to easier build scaled versions of objectives,
1414
especially turn maximisation problems into minimisation ones using a scaling of `-1`.
15-
* Introduce a `ConstrainedSetObjective`
15+
* Introduce a `ManifoldConstrainedSetObjective`
1616
* Introduce a `projected_gradient_method`
1717

1818

docs/src/plans/objective.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ linearized_forward_operator
205205

206206
```@docs
207207
ConstrainedManifoldObjective
208-
ConstrainedSetObjective
208+
ManifoldConstrainedSetObjective
209209
```
210210

211211
It might be beneficial to use the adapted problem to specify different ranges for the gradients of the constraints

src/Manopt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ export AbstractDecoratedManifoldObjective,
297297
AbstractManifoldSubObjective,
298298
AbstractPrimalDualManifoldObjective,
299299
ConstrainedManifoldObjective,
300+
ManifoldConstrainedSetObjective,
300301
EmbeddedManifoldObjective,
301302
ScaledManifoldObjective,
302303
ManifoldCountObjective,

src/plans/constrained_set_plan.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ConstrainedSetObjective{E, MO, PF, IF} <: AbstractManifoldObjective{E}
2+
ManifoldConstrainedSetObjective{E, MO, PF, IF} <: AbstractManifoldObjective{E}
33
44
Model a constrained objective restricted to a set
55
@@ -18,7 +18,7 @@ where ``$(_tex(:Cal,"C")) ⊂ $(_math(:M))`` is a convex closed subset.
1818
1919
# Constructor
2020
21-
ConstrainedSetObjective(f, grad_f, project!!; kwargs...)
21+
ManifoldConstrainedSetObjective(f, grad_f, project!!; kwargs...)
2222
2323
Generate the constrained objective for a given function `f` its gradient `grad_f` and a projection `project!!` ``$(_tex(:proj))_{$(_tex(:Cal,"C"))}``.
2424
@@ -28,22 +28,24 @@ $(_var(:Keyword, :evaluation))
2828
* `indicator=nothing`: the indicator function ``ι_{$(_tex(:Cal,"C"))}(p)``. If not provided a test, whether the projection yields the same point is performed.
2929
For the [`InplaceEvaluation`](@ref) this required one allocation.
3030
"""
31-
struct ConstrainedSetObjective{
31+
struct ManifoldConstrainedSetObjective{
3232
E<:AbstractEvaluationType,MO<:AbstractManifoldObjective,PF,IF
3333
} <: AbstractManifoldObjective{E}
3434
objective::MO
3535
project!!::PF
3636
indicator::IF
3737
end
3838

39-
function ConstrainedSetObjective(
39+
function ManifoldConstrainedSetObjective(
4040
f, grad_f, project!!::PF; evaluation::E=AllocatingEvaluation(), indicator=nothing
4141
) where {PF,E<:AbstractEvaluationType}
4242
obj = ManifoldGradientObjective(f, grad_f; evaluation=evaluation)
4343
if isnothing(indicator)
4444
if evaluation isa AllocatingEvaluation
4545
ind(M, p) = (distance(M, p, project!!(M, p)) 0 ? 0 : Inf)
46-
return ConstrainedSetObjective{E,typeof(obj),typeof(project!!),typeof(ind)}(
46+
return ManifoldConstrainedSetObjective{
47+
E,typeof(obj),typeof(project!!),typeof(ind)
48+
}(
4749
obj, project!!, ind
4850
)
4951
elseif evaluation isa InplaceEvaluation
@@ -52,39 +54,43 @@ function ConstrainedSetObjective(
5254
project!!(M, q, p)
5355
return distance(M, p, q) 0 ? 0 : Inf
5456
end
55-
return ConstrainedSetObjective{E,typeof(obj),typeof(project!!),typeof(ind)}(
57+
return ManifoldConstrainedSetObjective{
58+
E,typeof(obj),typeof(project!!),typeof(ind)
59+
}(
5660
obj, project!!, ind
5761
)
5862
end
5963
end
60-
return ConstrainedSetObjective{E,typeof(obj),typeof(project!!),typeof(indicator)}(
64+
return ManifoldConstrainedSetObjective{
65+
E,typeof(obj),typeof(project!!),typeof(indicator)
66+
}(
6167
obj, project!!, indicator
6268
)
6369
end
6470

65-
function get_cost(M::AbstractManifold, cso::ConstrainedSetObjective, p)
71+
function get_cost(M::AbstractManifold, cso::ManifoldConstrainedSetObjective, p)
6672
return get_cost(M, cso.objective, p)
6773
end
68-
function get_cost_function(cso::ConstrainedSetObjective, recursive=false)
69-
return get_cost_function(cso.objective)
74+
function get_cost_function(cso::ManifoldConstrainedSetObjective, recursive=false)
75+
return get_cost_function(cso.objective, recursive)
7076
end
71-
function get_gradient_function(cso::ConstrainedSetObjective, recursive=false)
72-
return get_gradient_function(cso.objective)
77+
function get_gradient_function(cso::ManifoldConstrainedSetObjective, recursive=false)
78+
return get_gradient_function(cso.objective, recursive)
7379
end
74-
function get_gradient(M::AbstractManifold, cso::ConstrainedSetObjective, p)
80+
function get_gradient(M::AbstractManifold, cso::ManifoldConstrainedSetObjective, p)
7581
return get_gradient(M, cso.objective, p)
7682
end
77-
function get_gradient!(M::AbstractManifold, X, cso::ConstrainedSetObjective, p)
83+
function get_gradient!(M::AbstractManifold, X, cso::ManifoldConstrainedSetObjective, p)
7884
return get_gradient!(M, X, cso.objective, p)
7985
end
8086

8187
_doc_get_projected_point = """
8288
get_projected_point(amp::AbstractManoptProblem, p)
8389
get_projected_point!(amp::AbstractManoptProblem, q, p)
84-
get_projected_point(M::AbstractManifold, cso::ConstrainedSetObjective, p)
85-
get_projected_point!(M::AbstractManifold, q, cso::ConstrainedSetObjective, p)
90+
get_projected_point(M::AbstractManifold, cso::ManifoldConstrainedSetObjective, p)
91+
get_projected_point!(M::AbstractManifold, q, cso::ManifoldConstrainedSetObjective, p)
8692
87-
Project `p` with the projection that is stored within the [`ConstrainedSetObjective`](@ref).
93+
Project `p` with the projection that is stored within the [`ManifoldConstrainedSetObjective`](@ref).
8894
This can be done in-place of `q`.
8995
"""
9096

@@ -98,29 +104,29 @@ function get_projected_point!(amp::AbstractManoptProblem, q, p)
98104
end
99105

100106
@doc "$(_doc_get_projected_point)"
101-
get_projected_point(M::AbstractManifold, cso::ConstrainedSetObjective, p)
107+
get_projected_point(M::AbstractManifold, cso::ManifoldConstrainedSetObjective, p)
102108
function get_projected_point(
103-
M::AbstractManifold, cso::ConstrainedSetObjective{AllocatingEvaluation}, p
109+
M::AbstractManifold, cso::ManifoldConstrainedSetObjective{AllocatingEvaluation}, p
104110
)
105111
return cso.project!!(M, p)
106112
end
107113
function get_projected_point(
108-
M::AbstractManifold, cso::ConstrainedSetObjective{InplaceEvaluation}, p
114+
M::AbstractManifold, cso::ManifoldConstrainedSetObjective{InplaceEvaluation}, p
109115
)
110116
q = copy(M, p)
111117
cso.project!!(M, q, p)
112118
return q
113119
end
114120
@doc "$(_doc_get_projected_point)"
115-
get_projected_point!(M::AbstractManifold, q, cso::ConstrainedSetObjective, p)
121+
get_projected_point!(M::AbstractManifold, q, cso::ManifoldConstrainedSetObjective, p)
116122
function get_projected_point!(
117-
M::AbstractManifold, q, cso::ConstrainedSetObjective{AllocatingEvaluation}, p
123+
M::AbstractManifold, q, cso::ManifoldConstrainedSetObjective{AllocatingEvaluation}, p
118124
)
119125
copyto!(M, q, cso.project!!(M, p))
120126
return q
121127
end
122128
function get_projected_point!(
123-
M::AbstractManifold, q, cso::ConstrainedSetObjective{InplaceEvaluation}, p
129+
M::AbstractManifold, q, cso::ManifoldConstrainedSetObjective{InplaceEvaluation}, p
124130
)
125131
cso.project!!(M, q, p)
126132
return q

src/plans/gradient_plan.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@ function ManifoldCostGradientObjective(
9292
return ManifoldCostGradientObjective{typeof(evaluation),CG}(costgrad)
9393
end
9494

95-
get_cost_function(cgo::ManifoldCostGradientObjective) = (M, p) -> get_cost(M, cgo, p)
96-
function get_gradient_function(cgo::ManifoldCostGradientObjective)
95+
function get_cost_function(cgo::ManifoldCostGradientObjective, recursive=false)
96+
return (M, p) -> get_cost(M, cgo, p)
97+
end
98+
function get_gradient_function(cgo::ManifoldCostGradientObjective, recursive=false)
9799
return (M, p) -> get_gradient(M, cgo, p)
98100
end
99101

src/solvers/projected_gradient_method.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,25 @@ function get_reason(c::StopWhenProjectedGradientStationary)
142142
end
143143
return ""
144144
end
145+
indicates_convergence(c::StopWhenProjectedGradientStationary) = true
146+
function show(io::IO, c::StopWhenProjectedGradientStationary)
147+
return print(
148+
io, "StopWhenProjectedGradientStationary($(c.threshold))\n $(status_summary(c))"
149+
)
150+
end
145151
function status_summary(c::StopWhenProjectedGradientStationary)
146152
has_stopped = (c.at_iteration >= 0)
147153
s = has_stopped ? "reached" : "not reached"
148-
return "stopped after $(c.threshold):\t$s"
154+
return "projected gradient stationary (<$(c.threshold)): \t$s"
149155
end
150-
151156
#
152157
#
153158
# The solver
154159
_doc_pgm = """
155160
projected_gradient_method(M, f, grad_f, proj, p=rand(M); kwargs...)
156-
projected_gradient_method(M, obj::ConstrainedSetObjective, p=rand(M); kwargs...)
161+
projected_gradient_method(M, obj::ManifoldConstrainedSetObjective, p=rand(M); kwargs...)
157162
projected_gradient_method!(M, f, grad_f, proj, p; kwargs...)
158-
projected_gradient_method!(M, obj::ConstrainedSetObjective, p; kwargs...)
163+
projected_gradient_method!(M, obj::ManifoldConstrainedSetObjective, p; kwargs...)
159164
160165
Compute the projected gradient method for the constrained problem
161166
@@ -202,12 +207,12 @@ end
202207
function projected_gradient_method(
203208
M, f, grad_f, proj, p; indicator=nothing, evaluation=AllocatingEvaluation(), kwargs...
204209
)
205-
cs_obj = ConstrainedSetObjective(
210+
cs_obj = ManifoldConstrainedSetObjective(
206211
f, grad_f, proj; evaluation=evaluation, indicator=indicator
207212
)
208213
return projected_gradient_method(M, cs_obj, p; kwargs...)
209214
end
210-
function projected_gradient_method(M, obj::ConstrainedSetObjective, p; kwargs...)
215+
function projected_gradient_method(M, obj::ManifoldConstrainedSetObjective, p; kwargs...)
211216
q = copy(M, p)
212217
return projected_gradient_method!(M, obj, q; kwargs...)
213218
end
@@ -216,14 +221,14 @@ end
216221
function projected_gradient_method!(
217222
M, f, grad_f, proj, p; indicator=nothing, evaluation=AllocatingEvaluation(), kwargs...
218223
)
219-
cs_obj = ConstrainedSetObjective(
224+
cs_obj = ManifoldConstrainedSetObjective(
220225
f, grad_f, proj; evaluation=evaluation, indicator=indicator
221226
)
222227
return projected_gradient_method!(M, cs_obj, p; kwargs...)
223228
end
224229
function projected_gradient_method!(
225230
M,
226-
obj::ConstrainedSetObjective,
231+
obj::ManifoldConstrainedSetObjective,
227232
p;
228233
backtrack::Stepsize=ArmijoLinesearchStepsize(M; stop_increasing_at_step=0),
229234
retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using Manifolds, Manopt, Random, Test
2+
3+
@testset "Constained set objective" begin
4+
M = Hyperbolic(2)
5+
c = Manifolds._hyperbolize(M, [0, 0])
6+
r = 1.0
7+
N = 200
8+
σ = 1.5
9+
Random.seed!(42)
10+
# N random points moved to top left to have a mean outside
11+
pts = [
12+
exp(
13+
M,
14+
c,
15+
get_vector(
16+
M,
17+
c,
18+
σ .* randn(manifold_dimension(M)) .+ [2.5, 2.5],
19+
DefaultOrthonormalBasis(),
20+
),
21+
) for _ in 1:N
22+
]
23+
f(M, p) = 1 / (2 * length(pts)) .* sum(distance(M, p, q)^2 for q in pts)
24+
grad_f(M, p) = -1 / length(pts) .* sum(log(M, p, q) for q in pts)
25+
function grad_f!(M, X, p)
26+
zero_vector!(M, X, p)
27+
Y = zero_vector(M, p)
28+
for q in pts
29+
log!(M, Y, p, q)
30+
X .+= Y
31+
end
32+
X .*= -1 / length(pts)
33+
return X
34+
end
35+
function project_C(M, p)
36+
X = log(M, c, p)
37+
n = norm(M, c, X)
38+
q = (n > r) ? exp(M, c, (r / n) * X) : copy(M, p)
39+
return q
40+
end
41+
function project_C!(M, q, p; X=zero_vector(M, c))
42+
log!(M, X, c, p)
43+
n = norm(M, c, X)
44+
if (n > r)
45+
exp!(M, q, c, (r / n) * X)
46+
else
47+
copyto!(M, q, p)
48+
end
49+
return q
50+
end
51+
g(M, p) = distance(M, c, p)^2 - r^2
52+
indicator_C(M, p) = (g(M, p) 0) ? 0 : Inf
53+
54+
csoa = ManifoldConstrainedSetObjective(f, grad_f, project_C)
55+
csoa2 = ManifoldConstrainedSetObjective(f, grad_f, project_C; indicator=indicator_C)
56+
csoi = ManifoldConstrainedSetObjective(
57+
f, grad_f!, project_C!; evaluation=InplaceEvaluation()
58+
)
59+
csoi2 = ManifoldConstrainedSetObjective(
60+
f, grad_f!, project_C!; evaluation=InplaceEvaluation(), indicator=indicator_C
61+
)
62+
63+
for objective in [csoa, csoa2, csoi, csoi2]
64+
@test get_cost(M, objective, c) == f(M, c)
65+
@test Manopt.get_cost_function(objective)(M, c) == f(M, c)
66+
@test get_gradient(M, objective, c) == grad_f(M, c)
67+
X = zero_vector(M, c)
68+
get_gradient!(M, X, objective, c)
69+
Y = zero_vector(M, c)
70+
grad_f!(M, Y, c)
71+
@test X == Y
72+
if objective [csoa, csoa2]
73+
@test Manopt.get_gradient_function(objective)(M, c) == grad_f(M, c)
74+
else
75+
Manopt.get_gradient_function(objective)(M, X, c) == grad_f!(M, Y, c)
76+
@test X == Y
77+
end
78+
dmp = DefaultManoptProblem(M, objective)
79+
p = get_projected_point(dmp, c)
80+
@test p == c # c is already in C
81+
get_projected_point!(dmp, p, c)
82+
@test p == c
83+
p = get_projected_point(M, objective, c)
84+
@test p == c
85+
get_projected_point!(M, p, objective, c)
86+
@test p == c
87+
end
88+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include("utils/example_tasks.jl")
2020
include("plans/test_nonlinear_least_squares_plan.jl")
2121
include("plans/test_gradient_plan.jl")
2222
include("plans/test_constrained_plan.jl")
23+
include("plans/test_constrained_set_plan.jl")
2324
include("plans/test_hessian_plan.jl")
2425
include("plans/test_parameters.jl")
2526
include("plans/test_primal_dual_plan.jl")

test/solvers/test_projected_gradient.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,35 @@ using Manifolds, Manopt, Random, Test
5757
stopping_criterion=StopAfterIteration(150) |
5858
StopWhenProjectedGradientStationary(M, 1e-7),
5959
)
60-
mean_pg_2 = copy(M, c)
61-
projected_gradient_method!(
60+
Random.seed!(42)
61+
mean_pg_2 = projected_gradient_method(
62+
M,
63+
f,
64+
grad_f,
65+
project_C;
66+
stopping_criterion=StopAfterIteration(150) |
67+
StopWhenProjectedGradientStationary(M, 1e-7),
68+
)
69+
@test isapprox(M, mean_pg_1, mean_pg_2)
70+
mean_pg_3 = copy(M, c)
71+
st = projected_gradient_method!(
6272
M,
6373
f,
6474
grad_f!,
6575
project_C!,
66-
mean_pg_2;
76+
mean_pg_3;
6777
evaluation=InplaceEvaluation(),
6878
stopping_criterion=StopAfterIteration(150) |
6979
StopWhenProjectedGradientStationary(M, 1e-7),
80+
return_state=true,
7081
)
71-
@test isapprox(M, mean_pg_1, mean_pg_2)
82+
@test isapprox(M, mean_pg_1, mean_pg_3)
83+
@test startswith(
84+
repr(st), "# Solver state for `Manopt.jl`s Projected Gradient Method\n"
85+
)
86+
stop_when_stationary = st.stop.criteria[2]
87+
@test repr(stop_when_stationary) ==
88+
"StopWhenProjectedGradientStationary($(stop_when_stationary.threshold))\n $(Manopt.status_summary(
89+
stop_when_stationary))"
90+
@test length(get_reason(stop_when_stationary)) > 0
7291
end

0 commit comments

Comments
 (0)