Skip to content

Commit 9dce89f

Browse files
committed
Finish further Record Actions
1 parent 53dd9d9 commit 9dce89f

3 files changed

Lines changed: 75 additions & 27 deletions

File tree

src/plans/first_order_plan.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,15 +1214,21 @@ function (r::RecordGradient{T})(
12141214
return record_or_reset!(r, get_gradient(s), k)
12151215
end
12161216
show(io::IO, ::RecordGradient{T}) where {T} = print(io, "RecordGradient($T)")
1217-
1217+
function status_summary(rg::RecordGradient; context::Symbol = :default)
1218+
(context === :short) && return ":Gradient"
1219+
return "A RecordAction to record the current gradient"
1220+
end
12181221
@doc """
1219-
RecordGradientNorm <: RecordAction
1222+
RecordGradientNorm{R<:Real} <: RecordAction
12201223
12211224
record the norm of the current gradient
1225+
1226+
## Constructor
1227+
RecordGradientNorm(r::Type{<:Real}=Float64)
12221228
"""
1223-
mutable struct RecordGradientNorm <: RecordAction
1224-
recorded_values::Array{Float64, 1}
1225-
RecordGradientNorm() = new(Array{Float64, 1}())
1229+
mutable struct RecordGradientNorm{R <: Real} <: RecordAction
1230+
recorded_values::Array{R, 1}
1231+
RecordGradientNorm(r::Type{<:Real}) = new{r}(Array{r, 1}())
12261232
end
12271233
function (r::RecordGradientNorm)(
12281234
mp::AbstractManoptProblem, ast::AbstractManoptSolverState, k::Int
@@ -1231,16 +1237,28 @@ function (r::RecordGradientNorm)(
12311237
return record_or_reset!(r, norm(M, get_iterate(ast), get_gradient(ast)), k)
12321238
end
12331239
show(io::IO, ::RecordGradientNorm) = print(io, "RecordGradientNorm()")
1240+
function status_summary(rg::RecordGradientNorm; context::Symbol = :default)
1241+
(context === :short) && return ":GradientNorm"
1242+
return "A RecordAction to record the current gradient norm"
1243+
end
12341244

12351245
@doc """
12361246
RecordStepsize <: RecordAction
12371247
1238-
record the step size
1248+
record the step size.
1249+
1250+
## Constructor
1251+
RecordStepsise(r::Type{<:Real}=Float64)
12391252
"""
1240-
mutable struct RecordStepsize <: RecordAction
1241-
recorded_values::Array{Float64, 1}
1242-
RecordStepsize() = new(Array{Float64, 1}())
1253+
mutable struct RecordStepsize{R <: Real} <: RecordAction
1254+
recorded_values::Array{R, 1}
1255+
RecordStepsize(r::Type{<:Real} = Float64) = new{r}(Array{r, 1}())
12431256
end
12441257
function (r::RecordStepsize)(p::AbstractManoptProblem, s::AbstractGradientSolverState, k)
12451258
return record_or_reset!(r, get_last_stepsize(p, s, k), k)
12461259
end
1260+
show(io::IO, ::RecordStepsize{R}) where {R} = print(io, "RecordStepsize($R)")
1261+
function status_summary(rg::RecordStepsize{R}; context::Symbol = :default) where {R}
1262+
(context === :short) && return ":Stepsize"
1263+
return "A RecordAction to record the current stepsize (of type $R)"
1264+
end

src/plans/proximal_plan.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,17 +333,25 @@ end
333333
#
334334
# Record
335335
@doc """
336-
RecordProximalParameter <: RecordAction
336+
RecordProximalParameter{R <: Real} <: RecordAction
337337
338338
record the current iterates proximal point algorithm parameter given by in
339339
[`AbstractManoptSolverState`](@ref)s `o.λ`.
340+
341+
## Constructor
342+
RecordProximalParameter(r::Type{<:Real}=Float64)
340343
"""
341-
mutable struct RecordProximalParameter <: RecordAction
342-
recorded_values::Array{Float64, 1}
343-
RecordProximalParameter() = new(Array{Float64, 1}())
344+
mutable struct RecordProximalParameter{R <: Real} <: RecordAction
345+
recorded_values::Array{R, 1}
346+
RecordProximalParameter(r::Type{<:Real}) = new{r}(Array{r, 1}())
344347
end
345348
function (r::RecordProximalParameter)(
346349
::AbstractManoptProblem, cpps::CyclicProximalPointState, k::Int
347350
)
348351
return record_or_reset!(r, cpps.λ(k), k)
349352
end
353+
show(io::IO, ::RecordProximalParameter{R}) where {R} = print(io, "RecordProximalParameter($R)")
354+
function status_summary(rg::RecordProximalParameter{R}; context::Symbol = :default) where {R}
355+
(context === :short) && return ":ProximalParameter"
356+
return "A RecordAction to record the current proximal parameter (of type $R)"
357+
end

src/plans/record.jl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,25 @@ function show(io::IO, re::RecordEvery)
264264
return print(io, "RecordEvery($(re.record), $(re.every), $(re.always_update))")
265265
end
266266
function status_summary(re::RecordEvery; context::Symbol = :default)
267-
if context
268-
s = ""
269-
if re.record isa RecordGroup
270-
s = status_summary(re.record)[3:(end - 2)]
271-
else
272-
s = "$(re.record)"
267+
if context === :short
268+
s = ""
269+
if re.record isa RecordGroup
270+
s = status_summary(re.record)[3:(end - 2)]
271+
else
272+
s = "$(re.record)"
273+
end
274+
return "[$s, $(re.every)]"
273275
end
274-
return "[$s, $(re.every)]"
276+
s = ""
277+
(re.every % 10 == 2) && (s = "every $(re.every)nd")
278+
(re.every % 10 == 3) && (s = "every $(re.every)rd")
279+
(re.every % 10 [2, 3]) && (s = "every $(re.every)th")
280+
(re.every == 1) && (s = "every")
281+
(context === :inline) && return "A RecordAction that records its inner action $s iteration"
282+
return """
283+
A RecordAction that records $s iteration with\n
284+
$(_in_str(status_summary(re.record; context = context); indent = 1))
285+
"""
275286
end
276287
get_record(r::RecordEvery) = get_record(r.record)
277288
get_record(r::RecordEvery, k) = get_record(r.record, k)
@@ -352,7 +363,7 @@ end
352363
function status_summary(rg::RecordGroup; context::Symbol = :default)
353364
(context === :short) && (return "[ $(join(["$(status_summary(ri))" for ri in rg.group], ", ")) ]")
354365
(context === :inline) && (return "A group of $(length(rg.group)) RecordActions")
355-
return "A group of $(length(rg.group)) RecordActions:\n $(join( ["* $(status_summary(ri; context=context))" for ri in rg.group], "\n"))\n"
366+
return "A group of $(length(rg.group)) RecordActions:\n $(join(["* $(status_summary(ri; context = context))" for ri in rg.group], "\n"))\n"
356367
end
357368
function show(io::IO, rg::RecordGroup)
358369
s = join(["$(ri)" for ri in rg.group], ", ")
@@ -433,7 +444,15 @@ end
433444
function show(io::IO, rsr::RecordSubsolver{R}) where {R}
434445
return print(io, "RecordSubsolver(; record=$(rsr.record), record_type=$R)")
435446
end
436-
status_summary(::RecordSubsolver) = ":Subsolver"
447+
function status_summary(::RecordSubsolver{R}; context::Symbol = :default) where {R}
448+
(context === :short) && return ":Subsolver"
449+
(context === :inline) && return "A RecordAction to specify something to record from each subolver run"
450+
return
451+
"""
452+
A RecordAction to record elements of type $R in from each subsolver run
453+
454+
"""
455+
end
437456

438457
@doc """
439458
RecordWhenActive <: RecordAction
@@ -736,8 +755,10 @@ function (r::RecordIteration)(::AbstractManoptProblem, ::AbstractManoptSolverSta
736755
return record_or_reset!(r, k, k)
737756
end
738757
show(io::IO, ::RecordIteration) = print(io, "RecordIteration()")
739-
status_summary(::RecordIteration) = ":Iteration"
740-
758+
function status_summary(::RecordIteration; context::Symbol = :default)
759+
(context === :short) && return ":Iteration"
760+
return "A RecordAction to record the current iteration number"
761+
end
741762
@doc """
742763
RecordStoppingReason <: RecordAction
743764
@@ -929,15 +950,15 @@ create a [`RecordAction`](@ref) where
929950
* a [`RecordAction`](@ref) is passed through
930951
* a [`Symbol`] creates
931952
* `:Change` to record the change of the iterates, see [`RecordChange`](@ref)
953+
* `:Cost` to record the current cost function value
932954
* `:Gradient` to record the gradient, see [`RecordGradient`](@ref)
933955
* `:GradientNorm to record the norm of the gradient, see [`RecordGradientNorm`](@ref)
934956
* `:Iterate` to record the iterate
935957
* `:Iteration` to record the current iteration number
936-
* `IterativeTime` to record the time iteratively
937-
* `:Cost` to record the current cost function value
958+
* `:IterativeTime` to record the times taken for each iteration.
959+
* `:ProximalParameter` to record the proximal parameter, see [`RecordProximalParameter`](@ref)
938960
* `:Stepsize` to record the current step size
939961
* `:Time` to record the total time taken after every iteration
940-
* `:IterativeTime` to record the times taken for each iteration.
941962
942963
and every other symbol is passed to [`RecordEntry`](@ref), which results in recording the
943964
field of the state with the symbol indicating the field of the solver to record.
@@ -952,6 +973,7 @@ function RecordActionFactory(s::AbstractManoptSolverState, symbol::Symbol)
952973
(symbol == :Iterate) && return RecordIterate(get_iterate(s))
953974
(symbol == :Iteration) && return RecordIteration()
954975
(symbol == :IterativeTime) && return RecordTime(; mode = :iterative)
976+
(symbol == :ProximalParameter) && return RecordProximalParameter()
955977
(symbol == :Stepsize) && return RecordStepsize()
956978
(symbol == :Stop) && return RecordStoppingReason()
957979
(symbol == :Subsolver) && return RecordSubsolver()

0 commit comments

Comments
 (0)