Skip to content

Commit b418aa3

Browse files
committed
A bit of code structuring and test the new scheme with GradientDescentState.
1 parent 799cce8 commit b418aa3

5 files changed

Lines changed: 59 additions & 39 deletions

File tree

src/plans/objective.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ the type `T` indicates the global [`AbstractEvaluationType`](@ref).
2222
"""
2323
abstract type AbstractManifoldObjective{E <: AbstractEvaluationType} end
2424

25+
function Base.show(io::IO, ::MIME"text/plain", amo::AbstractManifoldObjective)
26+
multiline = get(io, :multiline, true)
27+
if multiline
28+
return status_summary(io, amo)
29+
else
30+
show(io, amo)
31+
end
32+
end
33+
2534
@doc """
2635
AbstractDecoratedManifoldObjective{E<:AbstractEvaluationType,O<:AbstractManifoldObjective}
2736

src/plans/plan.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""
2+
status_summary(io, e)
23
status_summary(e)
34
4-
Return a string reporting about the current status of `e`,
5-
where `e` is a type from Manopt.
5+
Returns a string reporting about the current status of an element `e`
6+
defined in `Manopt.jl`, which can also directly be printed to an `IO` stream `io`.
67
7-
This method is similar to `show` but just returns a string.
8-
It might also be more verbose in explaining, or hide internal information.
8+
This method should generate a human readable summary of `e`,
9+
10+
By default, the variant with an `IO` stream dispatches to the one without to generate
11+
a string and prints it to the `IO` stream.
912
"""
10-
function status_summary(e)
11-
a = IOBuffer()
12-
Base.show(a, MIME"text/plain"(), e)
13-
return String(take!(a))
13+
function status_summary end
14+
15+
function status_summary(io::IO, e)
16+
return print(io, status_summary(e))
1417
end
1518

1619
"""

src/plans/problem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ Usually the cost should be within an [`AbstractManifoldObjective`](@ref).
1717
"""
1818
abstract type AbstractManoptProblem{M <: AbstractManifold} end
1919

20+
function Base.show(io::IO, ::MIME"text/plain", amp::AbstractManoptProblem)
21+
multiline = get(io, :multiline, true)
22+
if multiline
23+
return status_summary(io, amp)
24+
else
25+
show(io, amp)
26+
end
27+
end
28+
29+
2030
@doc """
2131
DefaultManoptProblem{TM <: AbstractManifold, Objective <: AbstractManifoldObjective}
2232

src/plans/solver_state.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ $(_fields(:stopping_criterion; name = "stop"))
1515
"""
1616
abstract type AbstractManoptSolverState end
1717

18+
function Base.show(io::IO, ::MIME"text/plain", ams::AbstractManoptSolverState)
19+
multiline = get(io, :multiline, true)
20+
if multiline
21+
return status_summary(io, ams)
22+
else
23+
show(io, ams)
24+
end
25+
end
26+
1827
"""
1928
ClosedFormSubSolverState{E<:AbstractEvaluationType} <: AbstractManoptSolverState
2029

src/solvers/gradient_descent.jl

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ mutable struct GradientDescentState{
5252
retraction_method::TRTM
5353
end
5454
function GradientDescentState(
55-
M::AbstractManifold;
55+
M::AbstractManifold = ManifoldsBase.DefaultManifold();
5656
p::P = rand(M),
5757
X::T = zero_vector(M, p),
5858
stopping_criterion::SC = StopAfterIteration(200) | StopWhenGradientNormLess(1.0e-8),
@@ -96,38 +96,27 @@ function get_message(gds::GradientDescentState)
9696
return get_message(gds.stepsize)
9797
end
9898

99-
function Base.show(io::IO, obj::GradientDescentState)
100-
return print_object(io, obj, multiline = false)
101-
end
102-
# the 3-argument show used by display(obj) on the REPL
103-
function Base.show(io::IO, ::MIME"text/plain", obj::GradientDescentState)
104-
multiline = get(io, :multiline, true)
105-
return print_object(io, obj, multiline = multiline)
99+
function Base.show(io::IO, gds::GradientDescentState)
100+
return "GradientDescentState(; direction=$(repr(gds.direction)), p=$(repr(gds.p)), stepsize=$(repr(gds.stepsize)), stopping_criterion=$(repr(gds.stop)), retraction_method=$(repr(gds.retraction_method)), X=$(repr(gds.X)))"
106101
end
107102

108-
function print_object(io::IO, gds::GradientDescentState; multiline::Bool)
109-
if multiline
110-
i = get_count(gds, :Iterations)
111-
Iter = (i > 0) ? "After $i iterations\n" : ""
112-
Conv = indicates_convergence(gds.stop) ? "Yes" : "No"
113-
s = """
114-
# Solver state for `Manopt.jl`s Gradient Descent
115-
$Iter
116-
## Parameters
117-
* retraction method: $(gds.retraction_method)
118-
119-
## Stepsize
120-
$(gds.stepsize)
121-
122-
## Stopping criterion
123-
124-
$(status_summary(gds.stop))
125-
This indicates convergence: $Conv"""
126-
return print(io, s)
127-
else
128-
# write something short, or go back to default mode
129-
return Base.show_default(io, gds)
130-
end
103+
function status_summary(gds::GradientDescentState)
104+
i = get_count(gds, :Iterations)
105+
Iter = (i > 0) ? "After $i iterations\n" : ""
106+
Conv = indicates_convergence(gds.stop) ? "Yes" : "No"
107+
s = """
108+
# Solver state for `Manopt.jl`s Gradient Descent
109+
$Iter
110+
## Parameters
111+
* retraction method: $(gds.retraction_method)
112+
113+
## Stepsize
114+
$(gds.stepsize)
115+
116+
## Stopping criterion
117+
$(status_summary(gds.stop))
118+
This indicates convergence: $Conv"""
119+
return s
131120
end
132121

133122
_doc_gd_iterate = raw"""

0 commit comments

Comments
 (0)