Skip to content

Commit a765107

Browse files
committed
Adapt the Trust RegionsState, start with MADS.
1 parent dcd213d commit a765107

4 files changed

Lines changed: 100 additions & 92 deletions

File tree

src/plans/mesh_adaptive_plan.jl

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,8 @@ $(_fields([:retraction_method, :vector_transport_method]))
9999
$(_kwargs([:retraction_method, :vector_transport_method, :X]))
100100
"""
101101
mutable struct LowerTriangularAdaptivePoll{
102-
P,
103-
T,
104-
F <: Real,
105-
V <: AbstractVector{F},
106-
M <: AbstractMatrix{F},
107-
I <: Int,
108-
B,
109-
VTM <: AbstractVectorTransportMethod,
110-
RM <: AbstractRetractionMethod,
102+
P, T, F <: Real, V <: AbstractVector{F}, M <: AbstractMatrix{F},
103+
I <: Int, B, VTM <: AbstractVectorTransportMethod, RM <: AbstractRetractionMethod,
111104
} <: AbstractMeshPollFunction
112105
base_point::P
113106
candidate::P
@@ -353,12 +346,8 @@ function show(io::IO, dmads::DefaultMeshAdaptiveDirectSearch)
353346
return print(io, s)
354347
end
355348
function (dmads::DefaultMeshAdaptiveDirectSearch)(
356-
amp::AbstractManoptProblem,
357-
mesh_size::Real,
358-
p,
359-
X;
360-
scale_mesh::Real = 1.0,
361-
max_stepsize::Real = Inf,
349+
amp::AbstractManoptProblem, mesh_size::Real, p, X;
350+
scale_mesh::Real = 1.0, max_stepsize::Real = Inf,
362351
)
363352
M = get_manifold(amp)
364353
dmads.X .= (4 * mesh_size * scale_mesh) .* X
@@ -391,11 +380,7 @@ $(_fields(:stopping_criterion; name = "stop"))
391380
392381
"""
393382
mutable struct MeshAdaptiveDirectSearchState{
394-
P,
395-
F <: Real,
396-
PT <: AbstractMeshPollFunction,
397-
ST <: AbstractMeshSearchFunction,
398-
SC <: StoppingCriterion,
383+
P, F <: Real, PT <: AbstractMeshPollFunction, ST <: AbstractMeshSearchFunction, SC <: StoppingCriterion,
399384
} <: AbstractManoptSolverState
400385
p::P
401386
mesh_size::F
@@ -407,8 +392,7 @@ mutable struct MeshAdaptiveDirectSearchState{
407392
search::ST
408393
end
409394
function MeshAdaptiveDirectSearchState(
410-
M::AbstractManifold,
411-
p::P = rand(M);
395+
M::AbstractManifold, p::P = rand(M);
412396
mesh_basis::B = default_basis(M, typeof(p)),
413397
scale_mesh::F = injectivity_radius(M) / 2,
414398
max_stepsize::F = injectivity_radius(M),
@@ -428,12 +412,8 @@ function MeshAdaptiveDirectSearchState(
428412
M, copy(M, p); retraction_method = retraction_method
429413
),
430414
) where {
431-
P,
432-
F,
433-
PT <: AbstractMeshPollFunction,
434-
ST <: AbstractMeshSearchFunction,
435-
SC <: StoppingCriterion,
436-
B <: AbstractBasis,
415+
P, F, PT <: AbstractMeshPollFunction, ST <: AbstractMeshSearchFunction,
416+
SC <: StoppingCriterion, B <: AbstractBasis,
437417
}
438418
poll_s = manifold_dimension(M) * 1.0
439419
return MeshAdaptiveDirectSearchState{P, F, PT, ST, SC}(
@@ -443,9 +423,13 @@ end
443423
get_iterate(mads::MeshAdaptiveDirectSearchState) = mads.p
444424

445425
function status_summary(mads::MeshAdaptiveDirectSearchState; context::Symbol = :default)
446-
i = get_count(mads, :Iterations)
426+
(context === :short) && repr(mads)
427+
i = get_count(trs, :Iterations)
428+
conv_inl = (i > 0) ? (indicates_convergence(trs.stop) ? " (converged" : " (stopped") * " after $i iterations)" : ""
429+
(context === :inline) && return "A solver state for the trust region solver$(conv_inl)"
447430
Iter = (i > 0) ? "After $i iterations\n" : ""
448-
_is_inline(context) && (return "$(repr(mads))$(Iter) $(has_converged(mads) ? "(converged)" : "")")
431+
Conv = indicates_convergence(trs.stop) ? "Yes" : "No"
432+
(context === :inline) && (return "A trust regions method state – $(Iter) $(has_converged(trs) ? "(converged)" : "")")
449433
s = """
450434
# Solver state for `Manopt.jl`s mesh adaptive direct search
451435
$Iter
@@ -454,11 +438,12 @@ function status_summary(mads::MeshAdaptiveDirectSearchState; context::Symbol = :
454438
* scale_mesh: $(mads.scale_mesh)
455439
* max_stepsize: $(mads.max_stepsize)
456440
* poll_size: $(mads.poll_size)
457-
* poll:\n $(replace(repr(mads.poll), "\n" => "\n ")[1:(end - 3)])
458-
* search:\n $(replace(repr(mads.search), "\n" => "\n ")[1:(end - 3)])
441+
* poll:\n $(_in_str(repr(mads.poll); indent = 1))
442+
* search:\n $(_in_str(repr(mads.search); indent = 1))
459443
460444
## Stopping criterion
461445
$(status_summary(mads.stop; context = context))
446+
This indicates convergence: $Conv
462447
"""
463448
return s
464449
end

src/solvers/exact_penalty_method.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function status_summary(epms::ExactPenaltyMethodState; context::Symbol = :defaul
126126
(context === :inline) && return "A solver state for the exact panelty method$(conv_inl)"
127127
Iter = (i > 0) ? "After $i iterations\n" : ""
128128
Conv = indicates_convergence(epms.stop) ? "Yes" : "No"
129-
_is_inline(context) && (return "$(repr(epms))$(Iter) $(has_converged(epms) ? "(converged)" : "")")
129+
(context === :inline) && (return "An exact penalty method state$(Iter) $(has_converged(epms) ? "(converged)" : "")")
130130
s = """
131131
# Solver state for `Manopt.jl`s Exact Penalty Method
132132
$Iter

src/solvers/trust_regions.jl

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -75,36 +75,33 @@ mutable struct TrustRegionsState{
7575
sub_state::St
7676
p_proposal::P
7777
f_proposal::R
78+
σ::R
79+
reduction_threshold::R
80+
reduction_factor::R
81+
augmentation_threshold::R
82+
augmentation_factor::R
7883
# Only required for Random mode Random
7984
HX::T
8085
Y::T
8186
HY::T
8287
Z::T
8388
HZ::T
8489
τ::R
85-
σ::R
86-
reduction_threshold::R
87-
reduction_factor::R
88-
augmentation_threshold::R
89-
augmentation_factor::R
90-
function TrustRegionsState{P, T, Pr, St, SC, RTR, R, Proj}(
91-
p::P,
92-
X::T,
93-
trust_region_radius::R,
94-
max_trust_region_radius::R,
95-
acceptance_rate::R,
96-
ρ_regularization::R,
97-
randomize::Bool,
98-
stopping_criterion::SC,
99-
retraction_method::RTR,
100-
reduction_threshold::R,
101-
augmentation_threshold::R,
102-
sub_problem::Pr,
103-
sub_state::St,
104-
project!::Proj = (copyto!),
105-
reduction_factor = 0.25,
106-
augmentation_factor = 2.0,
107-
σ::R = random ? 1.0e-6 : 0.0,
90+
function TrustRegionsState(
91+
sub_problem::Pr, sub_state::St;
92+
p::P, X::T,
93+
trust_region_radius::R, max_trust_region_radius::R, acceptance_rate::R,
94+
ρ_regularization::R, randomize::Bool,
95+
stopping_criterion::SC, retraction_method::RTR, reduction_threshold::R,
96+
augmentation_threshold::R, project!::Proj = (copyto!),
97+
reduction_factor::R, augmentation_factor::R, σ::R,
98+
#random mode ones can stay uninitielized if not provided
99+
HX::Union{T, Nothing} = nothing,
100+
Y::Union{T, Nothing} = nothing,
101+
HY::Union{T, Nothing} = nothing,
102+
Z::Union{T, Nothing} = nothing,
103+
HZ::Union{T, Nothing} = nothing,
104+
τ::Union{R, Nothing} = nothing,
108105
) where {
109106
P, T, Pr, St <: AbstractManoptSolverState,
110107
SC <: StoppingCriterion, RTR <: AbstractRetractionMethod, R <: Real, Proj,
@@ -127,51 +124,52 @@ mutable struct TrustRegionsState{
127124
trs.augmentation_factor = augmentation_factor
128125
trs.project! = project!
129126
trs.σ = σ
127+
!isnothing(HX) && (trs.HX = HX)
128+
!isnothing(Y) && (trs.Y = Y)
129+
!isnothing(HY) && (trs.HY = HY)
130+
!isnothing(Z) && (trs.HZ = Z)
131+
!isnothing(HZ) && (trs.HZ = HZ)
132+
!isnothing(τ) && (trs.τ = τ)
130133
return trs
131134
end
132135
end
133-
136+
TrustRegionsState(M::AbstractManifold, st::AbstractManoptSolverState; kwargs...) = error("Trust region method state can not be constructed based on $M and the sub state $st, a sub_problem is missing")
134137
function TrustRegionsState(
135138
M::AbstractManifold, sub_problem::Pr, sub_state::St;
136139
p::P = rand(M), X::T = zero_vector(M, p),
137-
acceptance_rate = 0.1,
138-
ρ_regularization::R = 1000.0,
140+
acceptance_rate::Real = 0.1, ρ_regularization::Real = 1000.0,
139141
randomize::Bool = false,
140142
stopping_criterion::SC = StopAfterIteration(1000) | StopWhenGradientNormLess(1.0e-6),
141-
max_trust_region_radius::R = sqrt(manifold_dimension(M)),
142-
trust_region_radius::R = max_trust_region_radius / 8,
143+
max_trust_region_radius::Real = sqrt(manifold_dimension(M)),
144+
trust_region_radius::Real = max_trust_region_radius / 8,
143145
retraction_method::RTR = default_retraction_method(M, typeof(p)),
144-
reduction_threshold::R = 0.1,
145-
reduction_factor = 0.25,
146-
augmentation_threshold::R = 0.75,
147-
augmentation_factor = 2.0,
148-
project!::Proj = (copyto!),
149-
σ = randomize ? 1.0e-4 : 0.0,
146+
reduction_threshold::Real = 0.1, reduction_factor = 0.25,
147+
augmentation_threshold::Real = 0.75, augmentation_factor::Real = 2.0,
148+
project!::Proj = (copyto!), σ::Real = randomize ? 1.0e-4 : 0.0,
150149
) where {
151150
P, T, Pr <: Union{AbstractManoptProblem, F} where {F}, St <: AbstractManoptSolverState,
152-
R <: Real,
153-
SC <: StoppingCriterion,
154-
RTR <: AbstractRetractionMethod,
155-
Proj,
151+
SC <: StoppingCriterion, RTR <: AbstractRetractionMethod, Proj,
156152
}
157-
return TrustRegionsState{P, T, Pr, St, SC, RTR, R, Proj}(
158-
p,
159-
X,
160-
trust_region_radius,
161-
max_trust_region_radius,
162-
acceptance_rate,
163-
ρ_regularization,
164-
randomize,
165-
stopping_criterion,
166-
retraction_method,
167-
reduction_threshold,
168-
augmentation_threshold,
169-
sub_problem,
170-
sub_state,
171-
project!,
172-
reduction_factor,
173-
augmentation_factor,
174-
σ,
153+
R = promote_type(
154+
typeof(acceptance_rate), typeof(ρ_regularization), typeof(max_trust_region_radius),
155+
typeof(trust_region_radius), typeof(reduction_threshold), typeof(reduction_factor),
156+
typeof(augmentation_factor), typeof(augmentation_threshold), typeof(σ)
157+
)
158+
acceptance_rate = convert(R, acceptance_rate); ρ_regularization = convert(R, ρ_regularization)
159+
max_trust_region_radius = convert(R, max_trust_region_radius); trust_region_radius = convert(R, trust_region_radius)
160+
reduction_threshold = convert(R, reduction_threshold); reduction_factor = convert(R, reduction_factor)
161+
augmentation_factor = convert(R, augmentation_factor); augmentation_threshold = convert(R, augmentation_threshold)
162+
σ = convert(R, σ)
163+
164+
return TrustRegionsState(
165+
sub_problem, sub_state;
166+
p = p, X = X,
167+
trust_region_radius = trust_region_radius, max_trust_region_radius = max_trust_region_radius,
168+
acceptance_rate = acceptance_rate, ρ_regularization = ρ_regularization,
169+
(project!) = project!, randomize = randomize, σ = σ,
170+
stopping_criterion = stopping_criterion, retraction_method = retraction_method,
171+
reduction_threshold = reduction_threshold, augmentation_threshold = augmentation_threshold,
172+
reduction_factor = reduction_factor, augmentation_factor = augmentation_factor,
175173
)
176174
end
177175
function TrustRegionsState(
@@ -203,14 +201,31 @@ function get_message(dcs::TrustRegionsState)
203201
# for now only the sub solver might have messages
204202
return get_message(dcs.sub_state)
205203
end
204+
function Base.show(io::IO, trs::TrustRegionsState)
205+
print(io, "TrustRegionsState("); print(io, trs.sub_problem); print(io, ", "); print(io, trs.sub_state)
206+
print(io, "; ")
207+
print(io, "p = $(trs.p), X = $(trs.X), ")
208+
print(io, "trust_region_radius = $(trs.trust_region_radius), max_trust_region_radius = $(trs.max_trust_region_radius), ")
209+
print(io, "acceptance_rate = $(trs.acceptance_rate), ρ_regularization = $(trs.ρ_regularization), randomize = $(trs.randomize), ")
210+
print(io, "reduction_threshold = $(trs.reduction_threshold), augmentation_threshold = $(trs.augmentation_threshold), ")
211+
print(io, "(project!) = $(trs.project!), reduction_factor = $(trs.reduction_factor), augmentation_factor = $(trs.augmentation_factor), σ = $(trs.σ), ")
212+
isdefined(trs, :HX) && print(io, "HX = $(trs.HX), ")
213+
isdefined(trs, :Y) && print(io, "Y = $(trs.Y), ")
214+
isdefined(trs, :HY) && print(io, "HY = $(trs.HY), ")
215+
isdefined(trs, :Z) && print(io, "Z = $(trs.Z), ")
216+
isdefined(trs, :HZ) && print(io, "HZ = $(trs.HZ), ")
217+
isdefined(trs, ) && print(io, "τ = $(trs.τ), ")
218+
print(io, "stopping_criterion = $(trs.stop), retraction_method = $(trs.retraction_method)")
219+
return print(io, ")")
220+
end
206221
function status_summary(trs::TrustRegionsState; context::Symbol = :default)
207222
(context === :short) && return repr(trs)
208223
i = get_count(trs, :Iterations)
209224
conv_inl = (i > 0) ? (indicates_convergence(trs.stop) ? " (converged" : " (stopped") * " after $i iterations)" : ""
210225
(context === :inline) && return "A solver state for the trust region solver$(conv_inl)"
211226
Iter = (i > 0) ? "After $i iterations\n" : ""
212227
Conv = indicates_convergence(trs.stop) ? "Yes" : "No"
213-
_is_inline(context) && (return "$(repr(trs))$(Iter) $(has_converged(trs) ? "(converged)" : "")")
228+
(context === :inline) && (return "A trust regions method state$(Iter) $(has_converged(trs) ? "(converged)" : "")")
214229
sub = _in_str(status_summary(trs.sub_state; context = context); indent = 1, headers = 1, indent_end = "| ")
215230
s = """
216231
# Solver state for `Manopt.jl`s Trust Region Method

test/solvers/test_trust_regions.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include("trust_region_model.jl")
3030
sub_state = TruncatedConjugateGradientState(TpM; X = get_gradient(M, mho, p))
3131
trs1 = TrustRegionsState(M, sub_problem)
3232
trs2 = TrustRegionsState(M, sub_problem, sub_state)
33+
@test_throws ErrorException TrustRegionsState(M, sub_state)
3334
trs3 = TrustRegionsState(M, sub_problem; p = p)
3435
@test Manopt.get_gradient_function(sub_objective)(M, p) == X
3536
end
@@ -58,17 +59,24 @@ include("trust_region_model.jl")
5859
Manopt.status_summary(s; context = :default),
5960
"# Solver state for `Manopt.jl`s Trust Region Method\n"
6061
)
62+
@test startswith(repr(s), "TrustRegionsState(")
63+
# not a random one -> does not contain HZ
64+
@test !contains(repr(s), "HZ = ")
6165
p1 = get_solver_result(s)
6266
q = copy(M, p)
6367
set_gradient!(s, M, p, zero_vector(M, p))
6468
@test norm(M, p, get_gradient(s)) 0.0
6569
trust_regions!(M, f, rgrad, rhess, q; max_trust_region_radius = 8.0)
6670
@test isapprox(M, p1, q)
6771
Random.seed!(42)
68-
p2 = trust_regions(
69-
M, f, rgrad, rhess, p; max_trust_region_radius = 8.0, randomize = true
72+
s2 = trust_regions(
73+
M, f, rgrad, rhess, p; max_trust_region_radius = 8.0, randomize = true, return_state = true
7074
)
75+
@test startswith(repr(s2), "TrustRegionsState(")
76+
# a random one -> does contain HZ
77+
@test contains(repr(s2), "HZ = ")
7178

79+
p2 = get_solver_result(s2)
7280
@test f(M, p2) f(M, p1)
7381

7482
p3 = trust_regions(

0 commit comments

Comments
 (0)