Skip to content

Commit 737ce28

Browse files
committed
Finish RecordAction refactor and aløready add a bit of testing,
1 parent 9dce89f commit 737ce28

5 files changed

Lines changed: 43 additions & 30 deletions

File tree

_typos.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ methodes = "methodes" # french
33
Serie = "Serie" # french
44
sur = "sur" # french
55
cmo = "cmo" # often used abbreviation for constrained manifold objective
6-
6+
nd = "nd" # like in 2nd
77
[files]
88
extend-exclude = [
99
"tutorials/*.html",

src/plans/first_order_plan.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ record the norm of the current gradient
12281228
"""
12291229
mutable struct RecordGradientNorm{R <: Real} <: RecordAction
12301230
recorded_values::Array{R, 1}
1231-
RecordGradientNorm(r::Type{<:Real}) = new{r}(Array{r, 1}())
1231+
RecordGradientNorm(r::Type{<:Real} = Float64) = new{r}(Array{r, 1}())
12321232
end
12331233
function (r::RecordGradientNorm)(
12341234
mp::AbstractManoptProblem, ast::AbstractManoptSolverState, k::Int

src/plans/proximal_plan.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ record the current iterates proximal point algorithm parameter given by in
343343
"""
344344
mutable struct RecordProximalParameter{R <: Real} <: RecordAction
345345
recorded_values::Array{R, 1}
346-
RecordProximalParameter(r::Type{<:Real}) = new{r}(Array{r, 1}())
346+
RecordProximalParameter(r::Type{<:Real} = Float64) = new{r}(Array{r, 1}())
347347
end
348348
function (r::RecordProximalParameter)(
349349
::AbstractManoptProblem, cpps::CyclicProximalPointState, k::Int

src/plans/record.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ function status_summary(re::RecordEvery; context::Symbol = :default)
267267
if context === :short
268268
s = ""
269269
if re.record isa RecordGroup
270-
s = status_summary(re.record)[3:(end - 2)]
270+
s = status_summary(re.record; context = context)[2:(end - 1)]
271271
else
272-
s = "$(re.record)"
272+
s = "$(status_summary(re.record; context = context))"
273273
end
274274
return "[$s, $(re.every)]"
275275
end
@@ -280,8 +280,8 @@ function status_summary(re::RecordEvery; context::Symbol = :default)
280280
(re.every == 1) && (s = "every")
281281
(context === :inline) && return "A RecordAction that records its inner action $s iteration"
282282
return """
283-
A RecordAction that records $s iteration with\n
284-
$(_in_str(status_summary(re.record; context = context); indent = 1))
283+
A RecordAction that records $s iteration with
284+
$(_MANOPT_INDENT)$(_in_str(status_summary(re.record; context = context); indent = 1))
285285
"""
286286
end
287287
get_record(r::RecordEvery) = get_record(r.record)
@@ -361,7 +361,7 @@ function (d::RecordGroup)(p::AbstractManoptProblem, s::AbstractManoptSolverState
361361
return
362362
end
363363
function status_summary(rg::RecordGroup; context::Symbol = :default)
364-
(context === :short) && (return "[ $(join(["$(status_summary(ri))" for ri in rg.group], ", ")) ]")
364+
(context === :short) && (return "[$(join(["$(status_summary(ri; context = context))" for ri in rg.group], ", "))]")
365365
(context === :inline) && (return "A group of $(length(rg.group)) RecordActions")
366366
return "A group of $(length(rg.group)) RecordActions:\n $(join(["* $(status_summary(ri; context = context))" for ri in rg.group], "\n"))\n"
367367
end
@@ -444,24 +444,25 @@ end
444444
function show(io::IO, rsr::RecordSubsolver{R}) where {R}
445445
return print(io, "RecordSubsolver(; record=$(rsr.record), record_type=$R)")
446446
end
447-
function status_summary(::RecordSubsolver{R}; context::Symbol = :default) where {R}
447+
function status_summary(rsr::RecordSubsolver{R}; context::Symbol = :default) where {R}
448448
(context === :short) && return ":Subsolver"
449-
(context === :inline) && return "A RecordAction to specify something to record from each subolver run"
450-
return
451-
"""
452-
A RecordAction to record elements of type $R in from each subsolver run
449+
(context === :inline) && return "A RecordAction to specify something to record from each subsolver run"
450+
return """
451+
A RecordAction to record elements in from each subsolver run of type $R.
453452
453+
## Recorded values
454+
The following recorded symbols from the sub state are recorded in every iteration of the (outer) solver
455+
$(join([ " * :$(s)" for s in rsr.record], "\n"))
454456
"""
455457
end
456458

457459
@doc """
458460
RecordWhenActive <: RecordAction
459461
460462
record action that only records if the `active` boolean is set to true.
461-
This can be set from outside and is for example triggered by |`RecordEvery`](@ref)
462-
on recordings of the subsolver.
463-
While this is for sub solvers maybe not completely necessary, recording values that
464-
are never accessible, is not that useful.
463+
This can be set from outside and is for example triggered by [`RecordEvery`](@ref)
464+
on recordings of a subsolver. While this is for sub solvers maybe not completely necessary,
465+
recording values that are never accessible, is not that useful.
465466
466467
# Fields
467468
@@ -482,7 +483,6 @@ mutable struct RecordWhenActive{R <: RecordAction} <: RecordAction
482483
return new{R}(r, active, always_update)
483484
end
484485
end
485-
486486
function (rwa::RecordWhenActive)(
487487
amp::AbstractManoptProblem, ams::AbstractManoptSolverState, k::Int
488488
)
@@ -495,8 +495,13 @@ end
495495
function show(io::IO, rwa::RecordWhenActive)
496496
return print(io, "RecordWhenActive($(rwa.record), $(rwa.active), $(rwa.always_update))")
497497
end
498-
function status_summary(rwa::RecordWhenActive)
499-
return repr(rwa)
498+
function status_summary(rwa::RecordWhenActive; context::Symbol = :default)
499+
(context === :short) && (return repr(rwa))
500+
(context === :inline) && (return "A RecordAction that only records its inner action when active (currently: $(rwa.active ? "" : "in")active)")
501+
return """
502+
Record the following only, when active (currently: $(rwa.active ? "" : "in")active)
503+
$(_in_str(status_summary(rwa.record; context = context), indent = 1, headers = 0))
504+
"""
500505
end
501506
function set_parameter!(rwa::RecordWhenActive, v::Val, args...)
502507
set_parameter!(rwa.record, v, args...)
@@ -615,7 +620,7 @@ end
615620
show(io::IO, ::RecordCost) = print(io, "RecordCost()")
616621
function status_summary(::RecordCost; context::Symbol = :default)
617622
(context === :short) && return ":Cost"
618-
return "A RecordAction to record the cost value."
623+
return "A RecordAction to record the cost value"
619624
end
620625

621626
@doc """

test/plans/test_record.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
2424
dmp = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f))
2525
a = RecordIteration()
2626
@test repr(a) == "RecordIteration()"
27-
@test Manopt.status_summary(a) == ":Iteration"
27+
@test Manopt.status_summary(a; context = :short) == ":Iteration"
28+
@test Manopt.status_summary(a) == "A RecordAction to record the current iteration number"
2829
# constructors
2930
rs = RecordSolverState(gds, a)
3031
Manopt.set_parameter!(rs, :Record, RecordCost())
@@ -85,7 +86,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
8586
@test_throws ErrorException RecordGroup(RecordAction[], Dict(:a => 1))
8687
@test_throws ErrorException RecordGroup(RecordAction[], Dict(:a => 0))
8788
b = RecordGroup([RecordIteration(), RecordIteration()], Dict(:It1 => 1, :It2 => 2))
88-
@test Manopt.status_summary(b) == "[ :Iteration, :Iteration ]"
89+
@test Manopt.status_summary(b; context = :short) == "[:Iteration, :Iteration]"
90+
@test startswith(Manopt.status_summary(b), "A group of 2 RecordActions:\n")
8991
@test repr(b) == "RecordGroup([RecordIteration(), RecordIteration()])"
9092
b(dmp, gds, 1)
9193
b(dmp, gds, 2)
@@ -102,7 +104,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
102104
@testset "RecordEvery" begin
103105
c = RecordEvery(a, 10, true)
104106
@test repr(c) == "RecordEvery(RecordIteration(), 10, true)"
105-
@test Manopt.status_summary(c) == "[RecordIteration(), 10]"
107+
@test Manopt.status_summary(c; context = :short) == "[:Iteration, 10]"
108+
@test startswith(Manopt.status_summary(c), "A RecordAction that records every 10th iteration with\n")
106109
c(dmp, gds, 0)
107110
@test length(get_record(c)) === 0
108111
c(dmp, gds, 1)
@@ -118,7 +121,7 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
118121
10,
119122
)
120123
@test repr(c2) == "RecordEvery($(repr(c2.record)), 10, true)"
121-
@test Manopt.status_summary(c2) == "[:Iteration, :Iteration, 10]"
124+
@test Manopt.status_summary(c2; context = :short) == "[:Iteration, :Iteration, 10]"
122125
c2(dmp, gds, 5)
123126
c2(dmp, gds, 10)
124127
c2(dmp, gds, 20)
@@ -129,7 +132,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
129132
d = RecordChange()
130133
sd = "RecordChange(; inverse_retraction_method=LogarithmicInverseRetraction())"
131134
@test repr(d) == sd
132-
@test Manopt.status_summary(d) == ":Change"
135+
@test Manopt.status_summary(d; context = :short) == ":Change"
136+
@test startswith(Manopt.status_summary(d), "A RecordAction to record the change of the iterate")
133137
d(dmp, gds, 1)
134138
@test d.recorded_values == [0.0] # no p0 -> assume p is the first iterate
135139
set_iterate!(gds, M, p + [1.0, 0.0])
@@ -169,7 +173,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
169173
@testset "RecordIterate" begin
170174
set_iterate!(gds, M, p)
171175
f = RecordIterate(p)
172-
@test Manopt.status_summary(f) == ":Iterate"
176+
@test Manopt.status_summary(f; context = :short) == ":Iterate"
177+
@test Manopt.status_summary(f) == "A RecordAction to record the current iterate"
173178
@test repr(f) == "RecordIterate(Vector{Float64})"
174179
@test_throws ErrorException RecordIterate()
175180
f(dmp, gds, 1)
@@ -178,7 +183,8 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
178183
@testset "RecordCost" begin
179184
g = RecordCost()
180185
@test repr(g) == "RecordCost()"
181-
@test Manopt.status_summary(g) == ":Cost"
186+
@test Manopt.status_summary(g; context = :short) == ":Cost"
187+
@test Manopt.status_summary(g) == "A RecordAction to record the cost value"
182188
g(dmp, gds, 1)
183189
@test g.recorded_values == [0.0]
184190
gds.p = [3.0, 2.0]
@@ -198,15 +204,17 @@ Manopt.get_parameter(d::TestRecordParameterState, ::Val{:value}) = d.value
198204
@testset "RecordSubsolver" begin
199205
rss = RecordSubsolver()
200206
@test repr(rss) == "RecordSubsolver(; record=[:Iteration], record_type=Any)"
201-
@test Manopt.status_summary(rss) == ":Subsolver"
207+
@test Manopt.status_summary(rss; context = :short) == ":Subsolver"
208+
@test startswith(Manopt.status_summary(rss), "A RecordAction to record elements in from each subsolver")
202209
epms = ExactPenaltyMethodState(M, dmp, rs)
203210
rss(dmp, epms, 1)
204211
end
205212
@testset "RecordWhenActive" begin
206213
i = RecordIteration()
207214
rwa = RecordWhenActive(i)
208215
@test repr(rwa) == "RecordWhenActive(RecordIteration(), true, true)"
209-
@test Manopt.status_summary(rwa) == repr(rwa)
216+
@test Manopt.status_summary(rwa; context = :short) == repr(rwa)
217+
@test startswith(Manopt.status_summary(rwa), "Record the following only, when active")
210218
rwa(dmp, gds, 1)
211219
@test length(get_record(rwa)) == 1
212220
rwa(dmp, gds, -1) # Reset

0 commit comments

Comments
 (0)