Skip to content

Commit 8e1790c

Browse files
committed
Unify all DebugActions.
1 parent 7655475 commit 8e1790c

12 files changed

Lines changed: 144 additions & 68 deletions

src/plans/bundle_plan.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ mutable struct DebugWarnIfLagrangeMultiplierIncreases <: DebugAction
181181
return new(warn, Float64(Inf), tol)
182182
end
183183
end
184-
function show(io::IO, di::DebugWarnIfLagrangeMultiplierIncreases)
185-
return print(io, "DebugWarnIfLagrangeMultiplierIncreases(; tol=\"$(di.tol)\")")
184+
function show(io::IO, d::DebugWarnIfLagrangeMultiplierIncreases)
185+
m = (d.status === :No ? "" : ":$(d.status)")
186+
return print(io, "DebugWarnIfLagrangeMultiplierIncreases($(m); tol=\"$(d.tol)\")")
187+
end
188+
function status_summary(d::DebugWarnIfLagrangeMultiplierIncreases; context::Symbol = :default)
189+
(context === :short) && return repr(d)
190+
m = (d.status === :Once) ? "once" : (d.status === :No ? "(inactive)" : "")
191+
return "a DebugAction warning if the lagange multiplier increases in an iteration $m."
186192
end

src/plans/debug.jl

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,10 @@ mutable struct DebugChange{IR <: AbstractInverseRetractionMethod} <: DebugAction
299299
function DebugChange(
300300
M::AbstractManifold = DefaultManifold();
301301
storage::Union{Nothing, StoreStateAction} = nothing,
302-
io::IO = stdout,
303-
prefix::String = "Last Change: ",
304-
format::String = "$(prefix)%f",
305-
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(
306-
M
307-
),
302+
io::IO = stdout, prefix::String = "Last Change: ", format::String = "$(prefix)%f",
303+
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
308304
)
309305
irm = inverse_retraction_method
310-
# Deprecated, remove in Manopt 0.5
311306
if isnothing(storage)
312307
if M isa DefaultManifold
313308
storage = StoreStateAction(M; store_fields = [:Iterate])
@@ -324,9 +319,7 @@ function (d::DebugChange)(mp::AbstractManoptProblem, st::AbstractManoptSolverSta
324319
d.io,
325320
Printf.Format(d.format),
326321
distance(
327-
M,
328-
get_iterate(st),
329-
get_storage(d.storage, PointStorageKey(:Iterate)),
322+
M, get_iterate(st), get_storage(d.storage, PointStorageKey(:Iterate)),
330323
d.inverse_retraction_method,
331324
),
332325
)
@@ -444,6 +437,10 @@ end
444437
function show(io::IO, di::DebugEntry)
445438
return print(io, "DebugEntry(:$(di.field); format=\"$(escape_string(di.format))\", at_init=$(di.at_init))")
446439
end
440+
function status_summary(di::DebugEntry; context::Symbol = :default)
441+
(context === :short) && return "(:$(di.field), format=\"$(escape_string(di.format))\")"
442+
return "A DebugAction to print the field :$(di.field) of the solver style with format \"$(escape_string(di.format))\""
443+
end
447444

448445
"""
449446
DebugFeasibility <: DebugAction
@@ -574,10 +571,14 @@ function (d::DebugIfEntry)(::AbstractManoptProblem, st::AbstractManoptSolverStat
574571
end
575572
return nothing
576573
end
577-
function show(io::IO, di::DebugIfEntry)
578-
return print(io, "DebugIfEntry(:$(di.field), $(di.check); type=:$(di.type), at_init=$(di.at_init))")
574+
function show(io::IO, d::DebugIfEntry)
575+
return print(io, "DebugIfEntry(:$(d.field), $(d.check); type=:$(d.type), at_init=$(d.at_init))")
576+
end
577+
function status_summary(d::DebugIfEntry; context = :Default)
578+
(context === :short) && (return repr(d))
579+
# Inline and default
580+
return "a DebugAction printing the entry :$(d.field) of the solver state if $(d.check) of that field is true, in format “$(escape_string(d.msg))” as $(d.type)"
579581
end
580-
581582
@doc """
582583
DebugEntryChange{T} <: DebugAction
583584
@@ -645,6 +646,10 @@ function show(io::IO, dec::DebugEntryChange)
645646
"DebugEntryChange(:$(dec.field), $(dec.distance); format=\"$(escape_string(dec.format))\")",
646647
)
647648
end
649+
function status_summary(d::DebugEntryChange; context::Symbol = :default)
650+
(context === :short) && return repr(d)
651+
return "a DebugAction that prints the change of the entry :$(d.field) of the solver state in format “$(escape_string(di.format))"
652+
end
648653

649654
@doc """
650655
DebugGradientChange()
@@ -838,15 +843,16 @@ end
838843
show(io::IO, d::DebugMessages) = print(io, "DebugMessages(:$(d.mode), :$(d.status))")
839844
function status_summary(d::DebugMessages; context::Symbol = :default)
840845
if context === :short
841-
(d.mode == :Warning) && return "(:WarningMessages, :$(d.status))"
842-
(d.mode == :Error) && return "(:ErrorMessages, :$(d.status))"
843-
# default
844-
# (d.mode == :Info) && return "(:InfoMessages, $(d.status)"
845-
return "(:Messages, :$(d.status))"
846+
s = ":Messages"
847+
(d.mode == :Warning) && (s = ":WarningMessages")
848+
(d.mode == :Error) && (s = ":ErrorMessages")
849+
(d.mode == :Info) && (s = ":InfoMessages")
850+
return d.status === :No ? s : "($s, :$(d.status))"
846851
end
847852
# Inline and default
848853
m = "a $(d.mode == :Warning ? "warning " : (d.mode == :Error ? "error " : ""))message"
849-
return "a DebugAction printing messages collected during the last iteration as $m"
854+
s = d.status === :No ? " (inactive)" : (d.status === :Once ? " once" : "")
855+
return "a DebugAction printing messages collected during the last iteration as $(m)$(s)."
850856
end
851857

852858
@doc """
@@ -1087,8 +1093,14 @@ function (d::DebugWarnIfCostIncreases)(
10871093
end
10881094
return nothing
10891095
end
1090-
function show(io::IO, di::DebugWarnIfCostIncreases)
1091-
return print(io, "DebugWarnIfCostIncreases(; tol=\"$(di.tol)\")")
1096+
function show(io::IO, d::DebugWarnIfCostIncreases)
1097+
m = (d.status === :No ? "" : ":$(d.status)")
1098+
return print(io, "DebugWarnIfCostIncreases($(m); tol=\"$(d.tol)\")")
1099+
end
1100+
function status_summary(d::DebugWarnIfCostIncreases; context::Symbol = :default)
1101+
(context === :short) && return repr(d)
1102+
m = (d.status === :Once) ? "once" : (d.status === :No ? "(inactive)" : "")
1103+
return "a DebugAction warning if the cost increases in an iteration $m."
10921104
end
10931105

10941106
@doc """
@@ -1251,7 +1263,14 @@ function (d::DebugWarnIfGradientNormTooLarge)(
12511263
return nothing
12521264
end
12531265
function show(io::IO, d::DebugWarnIfGradientNormTooLarge)
1254-
return print(io, "DebugWarnIfGradientNormTooLarge($(d.factor), :$(d.status))")
1266+
# only print status if active
1267+
m = (d.status === :No ? "" : ", :$(d.status)")
1268+
return print(io, "DebugWarnIfGradientNormTooLarge($(d.factor)$(m))")
1269+
end
1270+
function status_summary(d::DebugWarnIfGradientNormTooLarge; context::Symbol = :default)
1271+
(context === :short) && return repr(d)
1272+
m = (d.status === :Once) ? " once" : (d.status === :No ? " (inactive)" : "")
1273+
return "a DebugAction warning if the gradient norm gets larger than the maximal stepsize$m."
12551274
end
12561275

12571276
@doc """
@@ -1278,9 +1297,6 @@ mutable struct DebugWarnIfStepsizeCollapsed{T} <: DebugAction
12781297
return new{T}(warn, tol)
12791298
end
12801299
end
1281-
function show(io::IO, di::DebugWarnIfStepsizeCollapsed)
1282-
return print(io, "DebugWarnIfStepsizeCollapsed($(di.stop_when_stepsize_less), :$(di.status))")
1283-
end
12841300
function (d::DebugWarnIfStepsizeCollapsed)(
12851301
amp::AbstractManoptProblem, st::AbstractManoptSolverState, k::Int
12861302
)
@@ -1296,7 +1312,15 @@ function (d::DebugWarnIfStepsizeCollapsed)(
12961312
end
12971313
return nothing
12981314
end
1299-
1315+
function show(io::IO, d::DebugWarnIfStepsizeCollapsed)
1316+
m = (d.status === :No ? "" : ", :$(d.status)")
1317+
return print(io, "DebugWarnIfStepsizeCollapsed($(d.stop_when_stepsize_less)$(m))")
1318+
end
1319+
function status_summary(d::DebugWarnIfStepsizeCollapsed; context::Symbol = :default)
1320+
(context === :short) && return repr(d)
1321+
m = (d.status === :Once) ? " once" : (d.status === :No ? " (inactive)" : "")
1322+
return "a DebugAction warning if the step size collapses (below $(d.stop_when_stepsize_less))$m."
1323+
end
13001324
#
13011325
# Convenience constructors using Symbols
13021326
#
@@ -1461,13 +1485,14 @@ Note that the Shortcut symbols should all start with a capital letter.
14611485
* `:Iterate` creates a [`DebugIterate`](@ref)
14621486
* `:Iteration` creates a [`DebugIteration`](@ref)
14631487
* `:IterativeTime` creates a [`DebugTime`](@ref)`(:Iterative)`
1488+
* `:ProxParameter` creates a [`DebugProximalParameter`](@ref)`()`
14641489
* `:Stepsize` creates a [`DebugStepsize`](@ref)
14651490
* `:Stop` creates a [`StoppingCriterion`](@ref)`()`
1491+
* `:Time` creates a [`DebugTime`](@ref)
14661492
* `:WarnStepsize` creates a [`DebugWarnIfStepsizeCollapsed`](@ref)
14671493
* `:WarnBundle` creates a [`DebugWarnIfLagrangeMultiplierIncreases`](@ref)
14681494
* `:WarnCost` creates a [`DebugWarnIfCostNotFinite`](@ref)
14691495
* `:WarnGradient` creates a [`DebugWarnIfFieldNotFinite`](@ref) for the `::Gradient`.
1470-
* `:Time` creates a [`DebugTime`](@ref)
14711496
* `:WarningMessages` creates a [`DebugMessages`](@ref)`(:Warning)`
14721497
* `:InfoMessages` creates a [`DebugMessages`](@ref)`(:Info)`
14731498
* `:ErrorMessages` creates a [`DebugMessages`](@ref)`(:Error)`
@@ -1484,6 +1509,7 @@ function DebugActionFactory(d::Symbol)
14841509
(d == :Iterate) && return DebugIterate()
14851510
(d == :Iteration) && return DebugIteration()
14861511
(d == :Feasibility) && return DebugFeasibility()
1512+
(d == :ProxParameter) && return DebugProximalParameter()
14871513
(d == :Stepsize) && return DebugStepsize()
14881514
(d == :Stop) && return DebugStoppingCriterion()
14891515
(d == :WarnStepsize) && return DebugWarnIfStepsizeCollapsed()
@@ -1516,6 +1542,7 @@ Note that the Shortcut symbols `t[1]` should all start with a capital letter.
15161542
* `:GradientNorm` creates a [`DebugGradientNorm`](@ref)
15171543
* `:Iterate` creates a [`DebugIterate`](@ref)
15181544
* `:Iteration` creates a [`DebugIteration`](@ref)
1545+
* `:ProxParameter` creates a [`DebugProximalParameter`](@ref)
15191546
* `:Stepsize` creates a [`DebugStepsize`](@ref)
15201547
* `:Stop` creates a [`DebugStoppingCriterion`](@ref)
15211548
* `:Time` creates a [`DebugTime`](@ref)
@@ -1533,11 +1560,12 @@ function DebugActionFactory(t::Tuple{Symbol, Any})
15331560
(t[1] == :Iteration) && return DebugIteration(; format = t[2])
15341561
(t[1] == :Iterate) && return DebugIterate(; format = t[2])
15351562
(t[1] == :IterativeTime) && return DebugTime(; mode = :Iterative, format = t[2])
1563+
(t[1] == :ProxParameter) && return DebugProximalParameter(; format = t[2])
15361564
(t[1] == :Stepsize) && return DebugStepsize(; format = t[2])
15371565
(t[1] == :Stop) && return DebugStoppingCriterion(t[2])
15381566
(t[1] == :Time) && return DebugTime(; format = t[2])
15391567
((t[1] == :Messages) || (t[1] == :InfoMessages)) && return DebugMessages(:Info, t[2])
15401568
(t[1] == :WarningMessages) && return DebugMessages(:Warning, t[2])
1541-
(t[1] == :ErrorMessages) && return DebugMessages(:error, t[2])
1569+
(t[1] == :ErrorMessages) && return DebugMessages(:Error, t[2])
15421570
return DebugEntry(t[1]; format = t[2])
15431571
end

src/plans/first_order_plan.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,9 +1184,12 @@ function (d::DebugStepsize)(
11841184
return nothing
11851185
end
11861186
function show(io::IO, ds::DebugStepsize)
1187-
return print(io, "DebugStepsize(; format=\"$(ds.format)\", at_init=$(ds.at_init))")
1187+
return print(io, "DebugStepsize(; format=\"$(escape_string(ds.format))\", at_init=$(ds.at_init))")
1188+
end
1189+
function status_summary(ds::DebugStepsize; context::Symbol = :default)
1190+
(context === :short) && return "(:Stepsize, \"$(escape_string(ds.format))\")"
1191+
return "A DebugAction that prints the current step size to $(ds.io) in format “$(escape_string(ds.format))"
11881192
end
1189-
status_summary(ds::DebugStepsize) = "(:Stepsize, \"$(ds.format)\")"
11901193
#
11911194
# Records
11921195
#

src/plans/mesh_adaptive_plan.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ function Base.show(io::IO, dmads::DefaultMeshAdaptiveDirectSearch)
356356
return print(io, ")")
357357
end
358358
function status_summary(dmads::DefaultMeshAdaptiveDirectSearch; context = :default)
359-
(context === :short) && repr(dmads)
359+
(context === :short) && return repr(dmads)
360360
(context === :inline) && "The default mesh adaptive direct search along a given direction using the $(dmads.retraction_method)"
361361
return """The default mesh adaptive direct search
362362
along one given direction X.
@@ -450,7 +450,7 @@ function Base.show(io::IO, mads::MeshAdaptiveDirectSearchState)
450450
return print(io, "stopping_criterion = ", mads.stop, ", poll = ", mads.poll, ", search = ", mads.search, ")")
451451
end
452452
function status_summary(mads::MeshAdaptiveDirectSearchState; context::Symbol = :default)
453-
(context === :short) && repr(mads)
453+
(context === :short) && return repr(mads)
454454
i = get_count(mads, :Iterations)
455455
conv_inl = (i > 0) ? (indicates_convergence(mads.stop) ? " (converged" : " (stopped") * " after $i iterations)" : ""
456456
(context === :inline) && return "A solver state for the trust region solver$(conv_inl)"

0 commit comments

Comments
 (0)