@@ -436,20 +436,20 @@ function trust_regions!(
436436 retraction_method:: AbstractRetractionMethod = default_retraction_method (M, typeof (p)),
437437 stopping_criterion:: StoppingCriterion = StopAfterIteration (1000 ) |
438438 StopWhenGradientNormLess (1.0e-6 ),
439- max_trust_region_radius:: R = sqrt (manifold_dimension (M)),
440- trust_region_radius:: R = max_trust_region_radius / 8 ,
439+ max_trust_region_radius:: Real = sqrt (manifold_dimension (M)),
440+ trust_region_radius:: Real = max_trust_region_radius / 8 ,
441441 randomize:: Bool = false , # Deprecated, remove on next release (use just `σ`)
442442 project!:: Proj = (copyto!),
443- ρ_prime:: R = 0.1 , # Deprecated, remove on next breaking change (use `acceptance_rate` `)
444- acceptance_rate:: R = ρ_prime,
445- ρ_regularization = 1.0e3 ,
446- θ:: R = 1.0 ,
447- κ:: R = 0.1 ,
448- σ = randomize ? 1.0e-3 : 0.0 ,
449- reduction_threshold:: R = 0.1 ,
450- reduction_factor:: R = 0.25 ,
451- augmentation_threshold:: R = 0.75 ,
452- augmentation_factor:: R = 2.0 ,
443+ ρ_prime:: Real = 0.1 , # Deprecated, remove on next breaking change (use `acceptance_rate`)
444+ acceptance_rate:: Real = ρ_prime,
445+ ρ_regularization:: Real = 1.0e3 ,
446+ θ:: Real = 1.0 ,
447+ κ:: Real = 0.1 ,
448+ σ:: Real = randomize ? 1.0e-3 : 0.0 ,
449+ reduction_threshold:: Real = 0.1 ,
450+ reduction_factor:: Real = 0.25 ,
451+ augmentation_threshold:: Real = 0.75 ,
452+ augmentation_factor:: Real = 2.0 ,
453453 sub_kwargs = (;),
454454 sub_objective = decorate_objective! (M, TrustRegionModelObjective (mho); sub_kwargs... ),
455455 sub_problem = DefaultManoptProblem (TangentSpace (M, p), sub_objective),
@@ -475,7 +475,33 @@ function trust_regions!(
475475 sub_kwargs... ,
476476 ),
477477 kwargs... , # collect rest
478- ) where {Proj, O <: Union{ManifoldHessianObjective, AbstractDecoratedManifoldObjective} , R}
478+ ) where {Proj, O <: Union{ManifoldHessianObjective, AbstractDecoratedManifoldObjective} }
479+ R = float (
480+ promote_type (
481+ typeof (max_trust_region_radius),
482+ typeof (trust_region_radius),
483+ typeof (acceptance_rate),
484+ typeof (ρ_regularization),
485+ typeof (θ),
486+ typeof (κ),
487+ typeof (σ),
488+ typeof (reduction_threshold),
489+ typeof (reduction_factor),
490+ typeof (augmentation_threshold),
491+ typeof (augmentation_factor),
492+ ),
493+ )
494+ max_trust_region_radius = convert (R, max_trust_region_radius)
495+ trust_region_radius = convert (R, trust_region_radius)
496+ acceptance_rate = convert (R, acceptance_rate)
497+ ρ_regularization = convert (R, ρ_regularization)
498+ θ = convert (R, θ)
499+ κ = convert (R, κ)
500+ σ = convert (R, σ)
501+ reduction_threshold = convert (R, reduction_threshold)
502+ reduction_factor = convert (R, reduction_factor)
503+ augmentation_threshold = convert (R, augmentation_threshold)
504+ augmentation_factor = convert (R, augmentation_factor)
479505 (max_trust_region_radius <= 0 ) && throw (
480506 ErrorException (
481507 " max_trust_region_radius must be positive but it is $max_trust_region_radius ." ,
0 commit comments