@@ -75,36 +75,33 @@ mutable struct TrustRegionsState{
7575 sub_state:: St
7676 p_proposal:: P
7777 f_proposal:: R
78+ σ:: R
79+ reduction_threshold:: R
80+ reduction_factor:: R
81+ augmentation_threshold:: R
82+ augmentation_factor:: R
7883 # Only required for Random mode Random
7984 HX:: T
8085 Y:: T
8186 HY:: T
8287 Z:: T
8388 HZ:: T
8489 τ:: R
85- σ:: R
86- reduction_threshold:: R
87- reduction_factor:: R
88- augmentation_threshold:: R
89- augmentation_factor:: R
90- function TrustRegionsState {P, T, Pr, St, SC, RTR, R, Proj} (
91- p:: P ,
92- X:: T ,
93- trust_region_radius:: R ,
94- max_trust_region_radius:: R ,
95- acceptance_rate:: R ,
96- ρ_regularization:: R ,
97- randomize:: Bool ,
98- stopping_criterion:: SC ,
99- retraction_method:: RTR ,
100- reduction_threshold:: R ,
101- augmentation_threshold:: R ,
102- sub_problem:: Pr ,
103- sub_state:: St ,
104- project!:: Proj = (copyto!),
105- reduction_factor = 0.25 ,
106- augmentation_factor = 2.0 ,
107- σ:: R = random ? 1.0e-6 : 0.0 ,
90+ function TrustRegionsState (
91+ sub_problem:: Pr , sub_state:: St ;
92+ p:: P , X:: T ,
93+ trust_region_radius:: R , max_trust_region_radius:: R , acceptance_rate:: R ,
94+ ρ_regularization:: R , randomize:: Bool ,
95+ stopping_criterion:: SC , retraction_method:: RTR , reduction_threshold:: R ,
96+ augmentation_threshold:: R , project!:: Proj = (copyto!),
97+ reduction_factor:: R , augmentation_factor:: R , σ:: R ,
98+ # random mode ones can stay uninitielized if not provided
99+ HX:: Union{T, Nothing} = nothing ,
100+ Y:: Union{T, Nothing} = nothing ,
101+ HY:: Union{T, Nothing} = nothing ,
102+ Z:: Union{T, Nothing} = nothing ,
103+ HZ:: Union{T, Nothing} = nothing ,
104+ τ:: Union{R, Nothing} = nothing ,
108105 ) where {
109106 P, T, Pr, St <: AbstractManoptSolverState ,
110107 SC <: StoppingCriterion , RTR <: AbstractRetractionMethod , R <: Real , Proj,
@@ -127,51 +124,52 @@ mutable struct TrustRegionsState{
127124 trs. augmentation_factor = augmentation_factor
128125 trs. project! = project!
129126 trs. σ = σ
127+ ! isnothing (HX) && (trs. HX = HX)
128+ ! isnothing (Y) && (trs. Y = Y)
129+ ! isnothing (HY) && (trs. HY = HY)
130+ ! isnothing (Z) && (trs. HZ = Z)
131+ ! isnothing (HZ) && (trs. HZ = HZ)
132+ ! isnothing (τ) && (trs. τ = τ)
130133 return trs
131134 end
132135end
133-
136+ TrustRegionsState (M :: AbstractManifold , st :: AbstractManoptSolverState ; kwargs ... ) = error ( " Trust region method state can not be constructed based on $M and the sub state $st , a sub_problem is missing " )
134137function TrustRegionsState (
135138 M:: AbstractManifold , sub_problem:: Pr , sub_state:: St ;
136139 p:: P = rand (M), X:: T = zero_vector (M, p),
137- acceptance_rate = 0.1 ,
138- ρ_regularization:: R = 1000.0 ,
140+ acceptance_rate:: Real = 0.1 , ρ_regularization:: Real = 1000.0 ,
139141 randomize:: Bool = false ,
140142 stopping_criterion:: SC = StopAfterIteration (1000 ) | StopWhenGradientNormLess (1.0e-6 ),
141- max_trust_region_radius:: R = sqrt (manifold_dimension (M)),
142- trust_region_radius:: R = max_trust_region_radius / 8 ,
143+ max_trust_region_radius:: Real = sqrt (manifold_dimension (M)),
144+ trust_region_radius:: Real = max_trust_region_radius / 8 ,
143145 retraction_method:: RTR = default_retraction_method (M, typeof (p)),
144- reduction_threshold:: R = 0.1 ,
145- reduction_factor = 0.25 ,
146- augmentation_threshold:: R = 0.75 ,
147- augmentation_factor = 2.0 ,
148- project!:: Proj = (copyto!),
149- σ = randomize ? 1.0e-4 : 0.0 ,
146+ reduction_threshold:: Real = 0.1 , reduction_factor = 0.25 ,
147+ augmentation_threshold:: Real = 0.75 , augmentation_factor:: Real = 2.0 ,
148+ project!:: Proj = (copyto!), σ:: Real = randomize ? 1.0e-4 : 0.0 ,
150149 ) where {
151150 P, T, Pr <: Union{AbstractManoptProblem, F} where {F}, St <: AbstractManoptSolverState ,
152- R <: Real ,
153- SC <: StoppingCriterion ,
154- RTR <: AbstractRetractionMethod ,
155- Proj,
151+ SC <: StoppingCriterion , RTR <: AbstractRetractionMethod , Proj,
156152 }
157- return TrustRegionsState {P, T, Pr, St, SC, RTR, R, Proj} (
158- p,
159- X,
160- trust_region_radius,
161- max_trust_region_radius,
162- acceptance_rate,
163- ρ_regularization,
164- randomize,
165- stopping_criterion,
166- retraction_method,
167- reduction_threshold,
168- augmentation_threshold,
169- sub_problem,
170- sub_state,
171- project!,
172- reduction_factor,
173- augmentation_factor,
174- σ,
153+ R = promote_type (
154+ typeof (acceptance_rate), typeof (ρ_regularization), typeof (max_trust_region_radius),
155+ typeof (trust_region_radius), typeof (reduction_threshold), typeof (reduction_factor),
156+ typeof (augmentation_factor), typeof (augmentation_threshold), typeof (σ)
157+ )
158+ acceptance_rate = convert (R, acceptance_rate); ρ_regularization = convert (R, ρ_regularization)
159+ max_trust_region_radius = convert (R, max_trust_region_radius); trust_region_radius = convert (R, trust_region_radius)
160+ reduction_threshold = convert (R, reduction_threshold); reduction_factor = convert (R, reduction_factor)
161+ augmentation_factor = convert (R, augmentation_factor); augmentation_threshold = convert (R, augmentation_threshold)
162+ σ = convert (R, σ)
163+
164+ return TrustRegionsState (
165+ sub_problem, sub_state;
166+ p = p, X = X,
167+ trust_region_radius = trust_region_radius, max_trust_region_radius = max_trust_region_radius,
168+ acceptance_rate = acceptance_rate, ρ_regularization = ρ_regularization,
169+ (project!) = project!, randomize = randomize, σ = σ,
170+ stopping_criterion = stopping_criterion, retraction_method = retraction_method,
171+ reduction_threshold = reduction_threshold, augmentation_threshold = augmentation_threshold,
172+ reduction_factor = reduction_factor, augmentation_factor = augmentation_factor,
175173 )
176174end
177175function TrustRegionsState (
@@ -203,14 +201,31 @@ function get_message(dcs::TrustRegionsState)
203201 # for now only the sub solver might have messages
204202 return get_message (dcs. sub_state)
205203end
204+ function Base. show (io:: IO , trs:: TrustRegionsState )
205+ print (io, " TrustRegionsState(" ); print (io, trs. sub_problem); print (io, " , " ); print (io, trs. sub_state)
206+ print (io, " ; " )
207+ print (io, " p = $(trs. p) , X = $(trs. X) , " )
208+ print (io, " trust_region_radius = $(trs. trust_region_radius) , max_trust_region_radius = $(trs. max_trust_region_radius) , " )
209+ print (io, " acceptance_rate = $(trs. acceptance_rate) , ρ_regularization = $(trs. ρ_regularization) , randomize = $(trs. randomize) , " )
210+ print (io, " reduction_threshold = $(trs. reduction_threshold) , augmentation_threshold = $(trs. augmentation_threshold) , " )
211+ print (io, " (project!) = $(trs. project!) , reduction_factor = $(trs. reduction_factor) , augmentation_factor = $(trs. augmentation_factor) , σ = $(trs. σ) , " )
212+ isdefined (trs, :HX ) && print (io, " HX = $(trs. HX) , " )
213+ isdefined (trs, :Y ) && print (io, " Y = $(trs. Y) , " )
214+ isdefined (trs, :HY ) && print (io, " HY = $(trs. HY) , " )
215+ isdefined (trs, :Z ) && print (io, " Z = $(trs. Z) , " )
216+ isdefined (trs, :HZ ) && print (io, " HZ = $(trs. HZ) , " )
217+ isdefined (trs, :τ ) && print (io, " τ = $(trs. τ) , " )
218+ print (io, " stopping_criterion = $(trs. stop) , retraction_method = $(trs. retraction_method) " )
219+ return print (io, " )" )
220+ end
206221function status_summary (trs:: TrustRegionsState ; context:: Symbol = :default )
207222 (context === :short ) && return repr (trs)
208223 i = get_count (trs, :Iterations )
209224 conv_inl = (i > 0 ) ? (indicates_convergence (trs. stop) ? " (converged" : " (stopped" ) * " after $i iterations)" : " "
210225 (context === :inline ) && return " A solver state for the trust region solver$(conv_inl) "
211226 Iter = (i > 0 ) ? " After $i iterations\n " : " "
212227 Conv = indicates_convergence (trs. stop) ? " Yes" : " No"
213- _is_inline (context) && (return " $( repr (trs)) – $(Iter) $(has_converged (trs) ? " (converged)" : " " ) " )
228+ (context === :inline ) && (return " A trust regions method state – $(Iter) $(has_converged (trs) ? " (converged)" : " " ) " )
214229 sub = _in_str (status_summary (trs. sub_state; context = context); indent = 1 , headers = 1 , indent_end = " | " )
215230 s = """
216231 # Solver state for `Manopt.jl`s Trust Region Method
0 commit comments