Skip to content

Commit 2e7741e

Browse files
committed
Adapt Particle Swarm state.
1 parent b96d95d commit 2e7741e

2 files changed

Lines changed: 58 additions & 67 deletions

File tree

src/solvers/particle_swarm.jl

Lines changed: 57 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -48,72 +48,72 @@ $(_kwargs(:vector_transport_method))
4848
[`particle_swarm`](@ref)
4949
"""
5050
mutable struct ParticleSwarmState{
51-
P,
52-
T,
53-
TX <: AbstractVector{P},
54-
TVelocity <: AbstractVector{T},
55-
TParams <: Real,
56-
TStopping <: StoppingCriterion,
57-
TRetraction <: AbstractRetractionMethod,
58-
TInvRetraction <: AbstractInverseRetractionMethod,
59-
TVTM <: AbstractVectorTransportMethod,
51+
P, T, F <: Real, VP <: AbstractVector{P}, VT <: AbstractVector{T},
52+
SC <: StoppingCriterion, RM <: AbstractRetractionMethod,
53+
IRM <: AbstractInverseRetractionMethod, VTM <: AbstractVectorTransportMethod,
6054
} <: AbstractManoptSolverState
61-
swarm::TX
62-
positional_best::TX
55+
swarm::VP
56+
positional_best::VP
6357
p::P
64-
velocity::TVelocity
65-
inertia::TParams
66-
social_weight::TParams
67-
cognitive_weight::TParams
58+
velocity::VT
59+
inertia::F
60+
social_weight::F
61+
cognitive_weight::F
6862
q::P
6963
social_vector::T
7064
cognitive_vector::T
71-
stop::TStopping
72-
retraction_method::TRetraction
73-
inverse_retraction_method::TInvRetraction
74-
vector_transport_method::TVTM
75-
65+
stop::SC
66+
retraction_method::RM
67+
inverse_retraction_method::IRM
68+
vector_transport_method::VTM
69+
function ParticleSwarmState(;
70+
swarm::VP, positional_best::VP, p::P, velocity::VT,
71+
inertia::F, social_weight::F, cognitive_weight::F, q::P, social_vector::T,
72+
cognitive_vector::T, stopping_criterion::SC, retraction_method::RM,
73+
inverse_retraction_method::IRM, vector_transport_method::VTM
74+
) where {
75+
P, T, F, VP <: AbstractVector, VT <: AbstractVector, SC <: StoppingCriterion,
76+
RM <: AbstractRetractionMethod, IRM <: AbstractInverseRetractionMethod, VTM <: AbstractVectorTransportMethod,
77+
}
78+
return new{P, T, F, VP, VT, SC, RM, IRM, VTM}(
79+
swarm, positional_best, p, velocity, inertia, social_weight, cognitive_weight, q,
80+
social_vector, cognitive_vector, stopping_criterion, retraction_method,
81+
inverse_retraction_method, vector_transport_method
82+
)
83+
end
7684
function ParticleSwarmState(
77-
M::AbstractManifold,
78-
swarm::VP,
79-
velocity::VT;
80-
inertia = 0.65,
81-
social_weight = 1.4,
82-
cognitive_weight = 1.4,
85+
M::AbstractManifold, swarm::VP, velocity::VT;
86+
inertia::Real = 0.65, social_weight::Real = 1.4, cognitive_weight::Real = 1.4,
8387
stopping_criterion::SCT = StopAfterIteration(500) | StopWhenChangeLess(M, 1.0e-4),
8488
retraction_method::RTM = default_retraction_method(M, eltype(swarm)),
8589
inverse_retraction_method::IRM = default_inverse_retraction_method(M, eltype(swarm)),
8690
vector_transport_method::VTM = default_vector_transport_method(M, eltype(swarm)),
8791
) where {
88-
P,
89-
T,
90-
VP <: AbstractVector{<:P},
91-
VT <: AbstractVector{<:T},
92-
RTM <: AbstractRetractionMethod,
93-
SCT <: StoppingCriterion,
94-
IRM <: AbstractInverseRetractionMethod,
95-
VTM <: AbstractVectorTransportMethod,
92+
P, T, VP <: AbstractVector{<:P}, VT <: AbstractVector{<:T},
93+
RTM <: AbstractRetractionMethod, SCT <: StoppingCriterion,
94+
IRM <: AbstractInverseRetractionMethod, VTM <: AbstractVectorTransportMethod,
9695
}
97-
s = new{
98-
P, T, VP, VT, typeof(inertia + social_weight + cognitive_weight), SCT, RTM, IRM, VTM,
99-
}()
100-
s.swarm = swarm
101-
s.positional_best = copy.(Ref(M), swarm)
102-
s.q = copy(M, first(swarm))
103-
s.p = copy(M, first(swarm))
104-
s.social_vector = zero_vector(M, s.q)
105-
s.cognitive_vector = zero_vector(M, s.q)
106-
s.velocity = velocity
107-
s.inertia = inertia
108-
s.social_weight = social_weight
109-
s.cognitive_weight = cognitive_weight
110-
s.stop = stopping_criterion
111-
s.retraction_method = retraction_method
112-
s.inverse_retraction_method = inverse_retraction_method
113-
s.vector_transport_method = vector_transport_method
114-
return s
96+
R = promote_type(typeof(inertia), typeof(social_weight), typeof(cognitive_weight))
97+
inertia = convert(R, inertia); social_weight = convert(R, social_weight); cognitive_weight = convert(R, cognitive_weight)
98+
return ParticleSwarmState(;
99+
swarm = swarm, positional_best = copy.(Ref(M), swarm),
100+
q = copy(M, first(swarm)), p = copy(M, first(swarm)),
101+
social_vector = zero_vector(M, first(swarm)), cognitive_vector = zero_vector(M, first(swarm)),
102+
velocity = velocity, inertia = inertia, social_weight = social_weight, cognitive_weight = cognitive_weight,
103+
stopping_criterion = stopping_criterion, retraction_method = retraction_method,
104+
inverse_retraction_method = inverse_retraction_method, vector_transport_method = vector_transport_method
105+
)
115106
end
116107
end
108+
function Base.show(io::IO, pss::ParticleSwarmState)
109+
print(io, "ParticleSwarmState(; swarm = ", pss.swarm, ", positional_best = ", pss.positional_best)
110+
print(io, ", p = ", pss.p, ", velocity = ", pss.velocity)
111+
print(io, ", inertia = ", pss.inertia, ", social_weight = ", pss.social_weight, ", cognitive_weight = ", pss.cognitive_weight)
112+
print(io, ", q = ", pss.q, ", social_vector = ", pss.social_vector, ", cognitive_vector = ", pss.cognitive_vector)
113+
print(io, ", stopping_criterion = ", pss.stop, ", retraction_method = ", pss.retraction_method)
114+
print(io, ", inverse_retraction_method = ", pss.inverse_retraction_method, ", vector_transport_method = ", pss.vector_transport_method)
115+
return print(io, ")")
116+
end
117117
function status_summary(pss::ParticleSwarmState; context::Symbol = :default)
118118
(context === :short) && return repr(pss)
119119
i = get_count(pss, :Iterations)
@@ -303,16 +303,10 @@ function particle_swarm!(
303303
dmco = decorate_objective!(M, mco; kwargs...)
304304
mp = DefaultManoptProblem(M, dmco)
305305
pss = ParticleSwarmState(
306-
M,
307-
swarm,
308-
velocity;
309-
inertia = inertia,
310-
social_weight = social_weight,
311-
cognitive_weight = cognitive_weight,
312-
stopping_criterion = stopping_criterion,
313-
retraction_method = retraction_method,
314-
inverse_retraction_method = inverse_retraction_method,
315-
vector_transport_method = vector_transport_method,
306+
M, swarm, velocity;
307+
inertia = inertia, social_weight = social_weight, cognitive_weight = cognitive_weight,
308+
stopping_criterion = stopping_criterion, retraction_method = retraction_method,
309+
inverse_retraction_method = inverse_retraction_method, vector_transport_method = vector_transport_method,
316310
)
317311
dpss = decorate_state!(pss; kwargs...)
318312
solve!(mp, dpss)
@@ -333,11 +327,7 @@ function step_solver!(mp::AbstractManoptProblem, s::ParticleSwarmState, ::Any)
333327
M = get_manifold(mp)
334328
for i in 1:length(s.swarm)
335329
inverse_retract!(
336-
M,
337-
s.cognitive_vector,
338-
s.swarm[i],
339-
s.positional_best[i],
340-
s.inverse_retraction_method,
330+
M, s.cognitive_vector, s.swarm[i], s.positional_best[i], s.inverse_retraction_method,
341331
)
342332
inverse_retract!(M, s.social_vector, s.swarm[i], s.p, s.inverse_retraction_method)
343333
s.velocity[i] .=

test/solvers/test_particle_swarm.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Random
1515
Manopt.status_summary(o; context = :default),
1616
"# Solver state for `Manopt.jl`s Particle Swarm Optimization Algorithm\n"
1717
)
18+
@test startswith(repr(o), "ParticleSwarmState(;")
1819
g = get_solver_result(o)
1920

2021
initF = min(f.(Ref(M), p1)...)

0 commit comments

Comments
 (0)