Skip to content

Commit 351734e

Browse files
committed
a bit of code formatting and one test.
1 parent cb45f45 commit 351734e

1 file changed

Lines changed: 12 additions & 44 deletions

File tree

test/solvers/test_gradient_descent.jl

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@ using ManifoldDiff: grad_distance
2424
500,
2525
)
2626
s = gradient_descent(
27-
M,
28-
f,
29-
grad_f,
30-
data[1];
27+
M, f, grad_f, data[1];
3128
stopping_criterion = StopAfterIteration(200) | StopWhenChangeLess(M, 1.0e-16),
3229
stepsize = ArmijoLinesearch(; contraction_factor = 0.99),
3330
debug = d,
@@ -38,10 +35,7 @@ using ManifoldDiff: grad_distance
3835
res_debug = String(take!(my_io))
3936
@test res_debug === " f(x): 1.357071\n"
4037
p2 = gradient_descent(
41-
M,
42-
f,
43-
grad_f,
44-
data[1];
38+
M, f, grad_f, data[1];
4539
stopping_criterion = StopAfterIteration(200) | StopWhenChangeLess(M, 1.0e-16),
4640
stepsize = ArmijoLinesearch(; contraction_factor = 0.99),
4741
)
@@ -56,10 +50,7 @@ using ManifoldDiff: grad_distance
5650
stop_when_stepsize_exceeds = 0.9 * π,
5751
)
5852
p3 = gradient_descent(
59-
M,
60-
f,
61-
grad_f,
62-
data[1];
53+
M, f, grad_f, data[1];
6354
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
6455
stepsize = step,
6556
debug = [], # do not warn about increasing step here
@@ -75,10 +66,7 @@ using ManifoldDiff: grad_distance
7566
stop_when_stepsize_exceeds = 0.9 * π,
7667
)
7768
p4 = gradient_descent(
78-
M,
79-
f,
80-
grad_f,
81-
data[1];
69+
M, f, grad_f, data[1];
8270
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
8371
stepsize = step2,
8472
debug = [], # do not warn about increasing step here
@@ -94,40 +82,28 @@ using ManifoldDiff: grad_distance
9482
stop_when_stepsize_exceeds = 0.9 * π,
9583
)
9684
p5 = gradient_descent(
97-
M,
98-
f,
99-
grad_f,
100-
data[1];
85+
M, f, grad_f, data[1];
10186
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
10287
stepsize = step3,
10388
debug = [], # do not warn about increasing step here
10489
)
10590
@test isapprox(M, p, p5; atol = 1.0e-13)
10691
p6 = gradient_descent(
107-
M,
108-
f,
109-
grad_f,
110-
data[1];
92+
M, f, grad_f, data[1];
11193
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
11294
direction = Nesterov(; p = copy(M, data[1])),
11395
)
11496
@test isapprox(M, p, p6; atol = 1.0e-13)
11597
# Precon in simple scale down by 2
11698
p7 = gradient_descent(
117-
M,
118-
f,
119-
grad_f,
120-
data[1];
99+
M, f, grad_f, data[1];
121100
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
122101
direction = PreconditionedDirection((M, p, X) -> 0.5 .* X),
123102
)
124103
@test isapprox(M, p, p7; atol = 1.0e-13)
125104
# Precon in simple scale down by 2 – inplace
126105
p8 = gradient_descent(
127-
M,
128-
f,
129-
grad_f,
130-
data[1];
106+
M, f, grad_f, data[1];
131107
stopping_criterion = StopAfterIteration(1000) | StopWhenChangeLess(M, 1.0e-16),
132108
direction = PreconditionedDirection(
133109
(M, Y, p, X) -> (Y .= 0.5 .* X); evaluation = InplaceEvaluation()
@@ -170,20 +146,14 @@ using ManifoldDiff: grad_distance
170146
# `gradient_descent` allocated n2 newly
171147
@test isapprox(M, north, n2a)
172148
n3 = gradient_descent(
173-
M,
174-
f,
175-
grad_f,
176-
pts[1];
149+
M, f, grad_f, pts[1];
177150
direction = MomentumGradient(),
178151
stepsize = ConstantLength(),
179152
debug = [], # do not warn about increasing step here
180153
)
181154
@test isapprox(M, north, n3)
182155
n4 = gradient_descent(
183-
M,
184-
f,
185-
grad_f,
186-
pts[1];
156+
M, f, grad_f, pts[1];
187157
direction = AverageGradient(M; n = 5),
188158
stopping_criterion = StopAfterIteration(800),
189159
)
@@ -194,14 +164,12 @@ using ManifoldDiff: grad_distance
194164
@test startswith(repr(r), "# Solver state for `Manopt.jl`s Gradient Descent")
195165
# State and a count objective, putting stats behind print
196166
n6 = gradient_descent(
197-
M,
198-
f,
199-
grad_f,
200-
pts[1];
167+
M, f, grad_f, pts[1];
201168
count = [:Gradient],
202169
return_objective = true,
203170
return_state = true,
204171
)
172+
@test stopped_at(n6) > 0
205173
@test repr(n6) == "$(n6[2])\n\n$(n6[1])"
206174
end
207175
@testset "Tutorial mode" begin

0 commit comments

Comments
 (0)