Skip to content

Commit 7bd3109

Browse files
committed
Finish the last gradient processors.
1 parent 66bfff5 commit 7bd3109

3 files changed

Lines changed: 36 additions & 4 deletions

File tree

src/plans/first_order_plan.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,11 @@ mutable struct PreconditionedDirectionRule{
10421042
} <: DirectionUpdateRule
10431043
preconditioner::F
10441044
direction::D
1045+
function PreconditionedDirectionRule(;
1046+
preconditioner::F, direction::D, evaluation::E
1047+
) where {E <: AbstractEvaluationType, D <: DirectionUpdateRule, F}
1048+
return new{E, D, F}(preconditioner, direction)
1049+
end
10451050
end
10461051
function PreconditionedDirectionRule(
10471052
M::AbstractManifold,
@@ -1050,7 +1055,7 @@ function PreconditionedDirectionRule(
10501055
evaluation::E = AllocatingEvaluation(),
10511056
) where {E <: AbstractEvaluationType, F}
10521057
dir = _produce_type(direction, M)
1053-
return PreconditionedDirectionRule{E, typeof(dir), F}(preconditioner, dir)
1058+
return PreconditionedDirectionRule(; preconditioner = preconditioner, direction = direction, evaluation = evaluation)
10541059
end
10551060
function (pg::PreconditionedDirectionRule{AllocatingEvaluation})(
10561061
mp::AbstractManoptProblem, s::AbstractGradientSolverState, k
@@ -1072,6 +1077,22 @@ function (pg::PreconditionedDirectionRule{InplaceEvaluation})(
10721077
pg.preconditioner(M, dir, p, dir)
10731078
return step, dir
10741079
end
1080+
function Base.show(io::IO, pg::PreconditionedDirectionRule{E}) where {E <: AbstractEvaluationType}
1081+
print(io, "PreconditionedDirectionRule(; direction = ", pg.direction, ", preconditioner = ", pg.preconditioner, ", ", _to_kw(E))
1082+
return print(io, ")")
1083+
end
1084+
function status_summary(pg::PreconditionedDirectionRule; context::Symbol = :default)
1085+
(context === :short) && return repr(pg)
1086+
(context === :inline) && return "A preconditioner gradient processor"
1087+
return """
1088+
Preconditioned Direction Rule
1089+
1090+
## Parameters
1091+
preconditioner: $(_MANOPT_INDENT)$(nr.μ)
1092+
## Direction Rule
1093+
$(_in_str(status_summary(pg.direction; context = context); indent = 1, headers = 1))
1094+
"""
1095+
end
10751096

10761097
"""
10771098
PreconditionedDirection(preconditioner; kwargs...)

src/solvers/alternating_gradient_descent.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@ function AlternatingGradientRule(
2727
) where {T}
2828
return AlternatingGradientRule{T}(X)
2929
end
30-
30+
function Base.show(io::IO, ag::AlternatingGradientRule)
31+
return print(io, "AlternatingGradientRule($(ag.X)")
32+
end
33+
function status_summary(ag::AlternatingGradientRule; context::Symbol = :default)
34+
(context === :short) && return repr(ag)
35+
return "A alternating gradient processor"
36+
end
3137
"""
3238
AlternatingGradientDescentState <: AbstractGradientDescentSolverState
3339

src/solvers/stochastic_gradient_descent.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ function StochasticGradientRule(
141141
) where {T}
142142
return StochasticGradientRule{T}(X)
143143
end
144-
145144
function (sg::StochasticGradientRule)(
146145
apm::AbstractManoptProblem, sgds::StochasticGradientDescentState, k
147146
)
@@ -151,7 +150,13 @@ function (sg::StochasticGradientRule)(
151150
j = sgds.order_type == :Random ? rand(1:length(sgds.order)) : sgds.order[sgds.k]
152151
return sgds.stepsize(apm, sgds, k), get_gradient!(apm, sg.X, sgds.p, j)
153152
end
154-
153+
function Base.show(io::IO, sg::StochasticGradientRule)
154+
return print(io, "StochasticGradientRule($(sg.X)")
155+
end
156+
function status_summary(sg::StochasticGradientRule; context::Symbol = :default)
157+
(context === :short) && return repr(sg)
158+
return "A stochastic gradient processor"
159+
end
155160
@doc """
156161
StochasticGradient(; kwargs...)
157162
StochasticGradient(M::AbstractManifold; kwargs...)

0 commit comments

Comments
 (0)