Skip to content

Commit 99efe60

Browse files
authored
minor improvements for scaled objective (#450)
1 parent 2c7ec68 commit 99efe60

3 files changed

Lines changed: 19 additions & 18 deletions

File tree

Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Added
1313

14-
* a `ScaledManifoldObjective` to esier built scaled versions of objectives,
14+
* a `ScaledManifoldObjective` to easier build scaled versions of objectives,
1515
especially turn maximisation problems into minimisation ones using a scaling of `-1`.
1616

1717
## [0.5.11] April 8, 2025

src/plans/scaled_objective.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,67 +27,68 @@ Generate a scaled manifold objective based on `objective` with `scale` being `1`
2727
in the first, `scale=-1` in the second case. The multiplication from the left with a scalar
2828
is also overloaded.
2929
"""
30-
struct ScaledManifoldObjective{E,O2,O1<:AbstractManifoldObjective{E},F} <:
31-
AbstractDecoratedManifoldObjective{E,O2}
30+
struct ScaledManifoldObjective{
31+
E<:AbstractEvaluationType,O2,O1<:AbstractManifoldObjective{E},F
32+
} <: AbstractDecoratedManifoldObjective{E,O2}
3233
objective::O1
3334
scale::F
3435
end
3536
function ScaledManifoldObjective(
3637
objective::O, scale::F=1
37-
) where {E<:AbstractEvaluationType,O<:AbstractManifoldObjective{E},F}
38+
) where {E<:AbstractEvaluationType,O<:AbstractManifoldObjective{E},F<:Real}
3839
return ScaledManifoldObjective{E,O,O,F}(objective, scale)
3940
end
4041
function ScaledManifoldObjective(
4142
objective::O1, scale::F=1
4243
) where {
43-
F,
44+
F<:Real,
4445
E<:AbstractEvaluationType,
4546
O2<:AbstractManifoldObjective,
4647
O1<:AbstractDecoratedManifoldObjective{E,O2},
4748
}
4849
return ScaledManifoldObjective{E,O2,O1,F}(objective, scale)
4950
end
50-
Base.:-(objective::AbstractManifoldObjective) = ScaledManifoldObjective(objective, -1.0)
51-
function Base.:*(scale::F, objective::AbstractManifoldObjective) where {F}
51+
Base.:-(objective::AbstractManifoldObjective) = ScaledManifoldObjective(objective, -1)
52+
function Base.:*(scale::Real, objective::AbstractManifoldObjective)
5253
return ScaledManifoldObjective(objective, scale)
5354
end
5455

5556
@doc """
5657
get_cost(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p)
5758
58-
Evaluated the scaled objective. ``s*f(p)``
59+
Evaluate the scaled objective. ``s*f(p)``
5960
"""
6061
function get_cost(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p)
6162
return scaled_objective.scale * get_cost(M, scaled_objective.objective, p)
6263
end
6364

64-
function get_cost_function(scaled_objective::ScaledManifoldObjective, recursive=false)
65+
function get_cost_function(scaled_objective::ScaledManifoldObjective, recursive::Bool=false)
6566
recursive && (return get_cost_function(scaled_objective.objective, recursive))
6667
return (M, p) -> scaled_objective.scale * get_cost(M, scaled_objective, p)
6768
end
6869
@doc """
6970
get_gradient(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p)
7071
get_gradient!(M::AbstractManifold, X, scaled_objective::ScaledManifoldObjective, p)
7172
72-
Evaluated the scaled gradient. ``s*$(_tex(:grad))f(p)``
73+
Evaluate the scaled gradient. ``s*$(_tex(:grad))f(p)``
7374
"""
7475
function get_gradient(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p)
7576
return scaled_objective.scale * get_gradient(M, scaled_objective.objective, p)
7677
end
7778
function get_gradient!(M::AbstractManifold, X, scaled_objective::ScaledManifoldObjective, p)
7879
get_gradient!(M, X, scaled_objective.objective, p)
79-
X .= scaled_objective.scale * X
80+
X .= scaled_objective.scale .* X
8081
return X
8182
end
8283

8384
function get_gradient_function(
84-
scaled_objective::ScaledManifoldObjective{AllocatingEvaluation}, recursive=false
85+
scaled_objective::ScaledManifoldObjective{AllocatingEvaluation}, recursive::Bool=false
8586
)
8687
recursive && (return get_gradient_function(scaled_objective.objective, recursive))
8788
return (M, p) -> get_gradient(M, scaled_objective, p)
8889
end
8990
function get_gradient_function(
90-
scaled_objective::ScaledManifoldObjective{InplaceEvaluation}, recursive=false
91+
scaled_objective::ScaledManifoldObjective{InplaceEvaluation}, recursive::Bool=false
9192
)
9293
recursive && (return get_gradient_function(scaled_objective.objective, recursive))
9394
return (M, X, p) -> get_gradient!(M, X, scaled_objective, p)
@@ -99,7 +100,7 @@ end
99100
get_hessian(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p, X)
100101
get_hessian!(M::AbstractManifold, Y, scaled_objective::ScaledManifoldObjective, p, X)
101102
102-
Evaluated the scaled Hessian ``s*$(_tex(:Hess))f(p)``
103+
Evaluate the scaled Hessian ``s*$(_tex(:Hess))f(p)``
103104
"""
104105
function get_hessian(M::AbstractManifold, scaled_objective::ScaledManifoldObjective, p, X)
105106
return scaled_objective.scale * get_hessian(M, scaled_objective.objective, p, X)
@@ -108,7 +109,7 @@ function get_hessian!(
108109
M::AbstractManifold, Y, scaled_objective::ScaledManifoldObjective, p, X
109110
)
110111
get_hessian!(M, Y, scaled_objective.objective, p, X)
111-
Y .= scaled_objective.scale * Y
112+
Y .= scaled_objective.scale .* Y
112113
return Y
113114
end
114115

test/plans/test_scaled_objective.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ using LinearAlgebra, Manifolds, Manopt, Test, Random
1616
obj! = ManifoldHessianObjective(f, ∇f!, ∇²f!; evaluation=InplaceEvaluation())
1717
neg_obj = -obj
1818
@test neg_obj isa ScaledManifoldObjective
19-
s = "ScaledManifoldObjective based on a $(obj) with scale -1.0"
19+
s = "ScaledManifoldObjective based on a $(obj) with scale -1"
2020
@test repr(neg_obj) == s
21-
scaled_obj = -1.0 * obj
21+
scaled_obj = -1 * obj
2222
@test scaled_obj == neg_obj
2323
scaled_obj! = -1.0 * obj!
2424
# just verify that this also works for double decorated ones.
2525
deco_obj = ScaledManifoldObjective(ManifoldCountObjective(M, obj, [:Cost]), 0.5)
2626

2727
#
28-
# Test and comare all accessors
28+
# Test and compare all accessors
2929
#
3030
for (s, o) in zip([scaled_obj, scaled_obj!], [obj, obj!])
3131
@test get_cost(M, s, p) == -f(M, p)

0 commit comments

Comments
 (0)