Skip to content

Commit 1dbf04b

Browse files
committed
add another state.
1 parent d5da0a3 commit 1dbf04b

4 files changed

Lines changed: 52 additions & 25 deletions

File tree

src/plans/plan.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,21 @@ end
3636
# ---
3737
# check whether a context is inline or less
3838
_is_inline(c) = (c == :inline || c == :short)
39-
# ind_str - indent a string for use within another one
39+
# _in_str - indent a string for use within another one
4040
# * `indent = false` raise indentation by `indent_str` (`_MANOPT_INDENT` by default)
4141
# * `headers = true` increase headers also on Headers that are indented with `indent_str`
42-
function _in_str(s::String; indent = 0, headers = 1, indent_str = _MANOPT_INDENT)
42+
# * `indent_str = _MANOPT_INDENT` string to use for indent
43+
# * `indent_end = ""` a string to end the indentation, for example a `"| "` for visual distinction
44+
function _in_str(s::String; indent = 0, headers = 1, indent_str = _MANOPT_INDENT, indent_end = "")
4345
t = s
46+
#add start
47+
t = replace("$(indent_end)$t", "\n" => "\n$(indent_end)")
48+
#add indent iteratively
4449
for _ in 1:indent
4550
t = replace("$(indent_str)$t", "\n" => "\n$(indent_str)")
4651
end
47-
for i in 1:headers
52+
# increase headers iteratively
53+
for _ in 1:headers
4854
t = replace(t, Regex("(?m)^($(indent_str)*)(#+)") => s"\1#\2")
4955
end
5056
return t

src/solvers/truncated_conjugate_gradient_descent.jl

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,15 @@ mutable struct TruncatedConjugateGradientState{T, R <: Real, SC <: StoppingCrite
9191
StopWhenModelIncreased(),
9292
kwargs...,
9393
) where {T, R <: Real, F}
94-
tcgs = new{T, R, typeof(stopping_criterion), F}()
94+
return TruncatedConjugateGradientState(;
95+
X = X, trust_region_radius = trust_region_radius, randomize = randomize,
96+
(project!) = project!, stopping_criterion = stopping_criterion,
97+
)
98+
end
99+
function TruncatedConjugateGradientState(;
100+
X::T, trust_region_radius::R, randomize::Bool, project!::F, stopping_criterion::SC,
101+
) where {T, R <: Real, F, SC <: StoppingCriterion}
102+
tcgs = new{T, R, SC, F}()
95103
tcgs.stop = stopping_criterion
96104
tcgs.Y = X
97105
tcgs.trust_region_radius = trust_region_radius
@@ -101,12 +109,22 @@ mutable struct TruncatedConjugateGradientState{T, R <: Real, SC <: StoppingCrite
101109
return tcgs
102110
end
103111
end
112+
function Base.show(io::IO, tcgs::TruncatedConjugateGradientState)
113+
print(io, "TruncatedConjugateGradientState(;")
114+
print(io, "(project!) = $(tcgs.project!), ")
115+
print(io, "randomize = $(tcgs.randomize), ")
116+
print(io, "stopping_criterion = $(tcgs.stop), ")
117+
print(io, "trust_region_radius = $(tcgs.trust_region_radius), ")
118+
return print(io, "X = $(tcgs.Y))")
119+
end
104120
function status_summary(tcgs::TruncatedConjugateGradientState; context = :default)
121+
(context === :short) && repr(tcgs)
105122
i = get_count(tcgs, :Iterations)
123+
conv_inl = (i > 0) ? (indicates_convergence(tcgs.stop) ? " (converged" : " (stopped") * " after $i iterations)" : ""
124+
(context === :inline) && "A solver state for the truncated conjugate gradient descent$(conv_inl)"
106125
Iter = (i > 0) ? "After $i iterations\n" : ""
107126
Conv = indicates_convergence(tcgs.stop) ? "Yes" : "No"
108-
_is_inline(context) && (return "$(repr(tcgs))$(Iter) $(has_converged(tcgs) ? "(converged)" : "")")
109-
s = """
127+
return """
110128
# Solver state for `Manopt.jl`s Truncated Conjugate Gradient Descent
111129
$Iter
112130
## Parameters
@@ -116,8 +134,8 @@ function status_summary(tcgs::TruncatedConjugateGradientState; context = :defaul
116134
## Stopping criterion
117135
$(status_summary(tcgs.stop; context = context))
118136
This indicates convergence: $Conv"""
119-
return s
120137
end
138+
121139
function set_parameter!(tcgs::TruncatedConjugateGradientState, ::Val{:Iterate}, Y)
122140
return tcgs.Y = Y
123141
end
@@ -191,10 +209,11 @@ function get_reason(c::StopWhenResidualIsReducedByFactorOrPower)
191209
return ""
192210
end
193211
function status_summary(c::StopWhenResidualIsReducedByFactorOrPower; context = :default)
194-
context === :short && return repr(c)
212+
(context === :short) && (return repr(c))
195213
has_stopped = (c.at_iteration >= 0)
196214
s = has_stopped ? "reached" : "not reached"
197-
return (_is_inline(context) ? "Residual reduced by factor $(c.κ) or power $(c.θ):$(_MANOPT_INDENT)" : "A stopping criterion used within tCG to check whether the residual is reduced by factor $(c.κ) or power 1+$(c.θ)\n$(_MANOPT_INDENT)") * "$s"
215+
(context === :inline) && (return "Residual reduced by factor $(c.κ) or power $(c.θ):$(_MANOPT_INDENT)$s")
216+
return "A stopping criterion used within tCG to check whether the residual is reduced by factor $(c.κ) or power 1+$(c.θ)\n$(_MANOPT_INDENT)$s"
198217
end
199218
function show(io::IO, c::StopWhenResidualIsReducedByFactorOrPower)
200219
return print(io, "StopWhenResidualIsReducedByFactorOrPower($(c.κ), $(c.θ))")
@@ -276,12 +295,13 @@ function get_reason(c::StopWhenTrustRegionIsExceeded)
276295
return ""
277296
end
278297
function status_summary(c::StopWhenTrustRegionIsExceeded; context = :default)
279-
(context == :short) && return repr(c)
298+
(context === :short) && (return repr(c))
280299
has_stopped = (c.at_iteration >= 0)
281300
s = has_stopped ? "reached" : "not reached"
282-
return (_is_inline(context) ? "Trust region exceeded:$(_MANOPT_INDENT)" : "A stopping criterion to stop when the trust region radius ($(c.trr)) is exceeded.\n$(_MANOPT_INDENT)") * "$s"
301+
(context === :inline) && (return "Trust region exceeded:$(_MANOPT_INDENT)$s")
302+
return "A stopping criterion to stop when the trust region radius (0.0) is exceeded.\n$(_MANOPT_INDENT)$s"
283303
end
284-
function show(io::IO, c::StopWhenTrustRegionIsExceeded)
304+
function show(io::IO, ::StopWhenTrustRegionIsExceeded)
285305
return print(io, "StopWhenTrustRegionIsExceeded()")
286306
end
287307

@@ -334,10 +354,11 @@ function get_reason(c::StopWhenCurvatureIsNegative)
334354
return ""
335355
end
336356
function status_summary(c::StopWhenCurvatureIsNegative; context = :default)
337-
(context == :short) && return repr(c)
357+
(context === :short) && (return repr(c))
338358
has_stopped = (c.at_iteration >= 0)
339359
s = has_stopped ? "reached" : "not reached"
340-
return (_is_inline(context) ? "Curvature is negative:$(_MANOPT_INDENT)" : "A stopping criterion to stop when the is negative\n$(_MANOPT_INDENT)") * "$s"
360+
(context === :inline) && (return "Curvature is negative:$(_MANOPT_INDENT)$s")
361+
return "A stopping criterion to stop when the is negative\n$(_MANOPT_INDENT)$s"
341362
end
342363
function show(io::IO, ::StopWhenCurvatureIsNegative)
343364
return print(io, "StopWhenCurvatureIsNegative()")
@@ -351,7 +372,7 @@ A functor for testing if the curvature of the model value increased.
351372
# Fields
352373
353374
$(_fields(:at_iteration))
354-
* `model_value`stre the last model value
375+
* `model_value` store the last model value
355376
* `inc_model_value` store the model value that increased
356377
357378
# Constructor
@@ -391,12 +412,13 @@ function get_reason(c::StopWhenModelIncreased)
391412
return ""
392413
end
393414
function status_summary(c::StopWhenModelIncreased; context = :default)
394-
context === :short && return repr(c)
415+
(context === :short) && (repr(c))
395416
has_stopped = (c.at_iteration >= 0)
396417
s = has_stopped ? "reached" : "not reached"
397-
return "Model Increased:$(_MANOPT_INDENT)$s"
418+
(context === :inline) && (return "Model Increased:$(_MANOPT_INDENT)$s")
419+
return "A stopping criterion to indicate when the model increased.\n$(_MANOPT_INDENT)$s"
398420
end
399-
function show(io::IO, ::StopWhenModelIncreased)
421+
function show(io::IO, c::StopWhenModelIncreased)
400422
return print(io, "StopWhenModelIncreased()")
401423
end
402424

@@ -424,8 +446,8 @@ solve the trust-region subproblem
424446
425447
$(_doc_TCG_subproblem)
426448
427-
on a manifold ``$(_math(:Manifold))nifold)))`` by using the Steihaug-Toint truncated conjugate-gradient (tCG) method.
428-
This can be done inplace of `X`.
449+
on a manifold ``$(_math(:Manifold))`` by using the Steihaug-Toint truncated conjugate-gradient (tCG) method.
450+
This can be done in-place of `X`.
429451
430452
For a description of the algorithm and theorems offering convergence guarantees,
431453
see [AbsilBakerGallivan:2006, ConnGouldToint:2000](@cite).

src/solvers/trust_regions.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ function status_summary(trs::TrustRegionsState; context = :default)
226226
Iter = (i > 0) ? "After $i iterations\n" : ""
227227
Conv = indicates_convergence(trs.stop) ? "Yes" : "No"
228228
_is_inline(context) && (return "$(repr(trs))$(Iter) $(has_converged(trs) ? "(converged)" : "")")
229-
sub = repr(trs.sub_state)
230-
sub = replace(sub, "\n" => "\n | ", "\n#" => "\n$(_MANOPT_INDENT)##")
229+
sub = _in_str(status_summary(trs.sub_state; context = context); indent = 1, headers = 1, indent_end = "| ")
231230
s = """
232231
# Solver state for `Manopt.jl`s Trust Region Method
233232
$Iter
@@ -239,8 +238,8 @@ function status_summary(trs::TrustRegionsState; context = :default)
239238
* retraction method: $(trs.retraction_method)
240239
* ρ_regularization: $(trs.ρ_regularization)
241240
* trust region radius: $(trs.trust_region_radius) (max: $(trs.max_trust_region_radius))
242-
* sub solver state :
243-
| $(sub)
241+
* sub solver state:
242+
$(sub)
244243
245244
## Stopping criterion
246245
$(status_summary(trs.stop; context = context))

test/solvers/test_truncated_cg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ using Manifolds, Manopt, ManifoldsBase, Test
2828
@test repr(scn) == "StopWhenCurvatureIsNegative()"
2929
smi = StopWhenModelIncreased()
3030
smi1 = Manopt.status_summary(smi)
31-
@test smi1 == "Model Increased:$(Manopt._MANOPT_INDENT)not reached"
31+
@test startswith(smi1, "A stopping criterion to indicate when the model increased.")
3232
@test repr(smi) == "StopWhenModelIncreased()"
3333
end

0 commit comments

Comments
 (0)