Skip to content

Commit de07beb

Browse files
committed
Fix RTR with Float32
1 parent 3f44394 commit de07beb

4 files changed

Lines changed: 55 additions & 14 deletions

File tree

Changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ The file was started with Version `0.4`.
66
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
77
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
88

9+
## [0.5.34] March 3, 2026
10+
11+
### Fixed
12+
13+
* `Float32` support in `trust_regions` solver was broken in the previous release, which is now fixed.
14+
915
## [0.5.33] February 18, 2026
1016

1117
### Added

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Manopt"
22
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
3-
version = "0.5.33"
3+
version = "0.5.34"
44
authors = ["Ronny Bergmann <manopt@ronnybergmann.net>"]
55

66
[workspace]

src/solvers/trust_regions.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -436,20 +436,20 @@ function trust_regions!(
436436
retraction_method::AbstractRetractionMethod = default_retraction_method(M, typeof(p)),
437437
stopping_criterion::StoppingCriterion = StopAfterIteration(1000) |
438438
StopWhenGradientNormLess(1.0e-6),
439-
max_trust_region_radius::R = sqrt(manifold_dimension(M)),
440-
trust_region_radius::R = max_trust_region_radius / 8,
439+
max_trust_region_radius::Real = sqrt(manifold_dimension(M)),
440+
trust_region_radius::Real = max_trust_region_radius / 8,
441441
randomize::Bool = false, # Deprecated, remove on next release (use just `σ`)
442442
project!::Proj = (copyto!),
443-
ρ_prime::R = 0.1, # Deprecated, remove on next breaking change (use `acceptance_rate``)
444-
acceptance_rate::R = ρ_prime,
445-
ρ_regularization = 1.0e3,
446-
θ::R = 1.0,
447-
κ::R = 0.1,
448-
σ = randomize ? 1.0e-3 : 0.0,
449-
reduction_threshold::R = 0.1,
450-
reduction_factor::R = 0.25,
451-
augmentation_threshold::R = 0.75,
452-
augmentation_factor::R = 2.0,
443+
ρ_prime::Real = 0.1, # Deprecated, remove on next breaking change (use `acceptance_rate`)
444+
acceptance_rate::Real = ρ_prime,
445+
ρ_regularization::Real = 1.0e3,
446+
θ::Real = 1.0,
447+
κ::Real = 0.1,
448+
σ::Real = randomize ? 1.0e-3 : 0.0,
449+
reduction_threshold::Real = 0.1,
450+
reduction_factor::Real = 0.25,
451+
augmentation_threshold::Real = 0.75,
452+
augmentation_factor::Real = 2.0,
453453
sub_kwargs = (;),
454454
sub_objective = decorate_objective!(M, TrustRegionModelObjective(mho); sub_kwargs...),
455455
sub_problem = DefaultManoptProblem(TangentSpace(M, p), sub_objective),
@@ -475,7 +475,33 @@ function trust_regions!(
475475
sub_kwargs...,
476476
),
477477
kwargs..., #collect rest
478-
) where {Proj, O <: Union{ManifoldHessianObjective, AbstractDecoratedManifoldObjective}, R}
478+
) where {Proj, O <: Union{ManifoldHessianObjective, AbstractDecoratedManifoldObjective}}
479+
R = float(
480+
promote_type(
481+
typeof(max_trust_region_radius),
482+
typeof(trust_region_radius),
483+
typeof(acceptance_rate),
484+
typeof(ρ_regularization),
485+
typeof(θ),
486+
typeof(κ),
487+
typeof(σ),
488+
typeof(reduction_threshold),
489+
typeof(reduction_factor),
490+
typeof(augmentation_threshold),
491+
typeof(augmentation_factor),
492+
),
493+
)
494+
max_trust_region_radius = convert(R, max_trust_region_radius)
495+
trust_region_radius = convert(R, trust_region_radius)
496+
acceptance_rate = convert(R, acceptance_rate)
497+
ρ_regularization = convert(R, ρ_regularization)
498+
θ = convert(R, θ)
499+
κ = convert(R, κ)
500+
σ = convert(R, σ)
501+
reduction_threshold = convert(R, reduction_threshold)
502+
reduction_factor = convert(R, reduction_factor)
503+
augmentation_threshold = convert(R, augmentation_threshold)
504+
augmentation_factor = convert(R, augmentation_factor)
479505
(max_trust_region_radius <= 0) && throw(
480506
ErrorException(
481507
"max_trust_region_radius must be positive but it is $max_trust_region_radius.",

test/solvers/test_trust_regions.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,4 +359,13 @@ include("trust_region_model.jl")
359359
@test f(M, q3) λ atol = 5 * 1.0e-8
360360
@test f(M, q4) λ atol = 5 * 1.0e-10
361361
end
362+
363+
@testset "Float32 support" begin
364+
M = Euclidean(3, 3)
365+
p = randn(Float32, 3, 3)
366+
f(M::Euclidean, p) = sum(p .^ 2) / 2
367+
grad(M::Euclidean, p) = p
368+
hess(M::Euclidean, p, X) = X
369+
trust_regions(M, f, grad, hess, p; max_trust_region_radius = 0.1f0)
370+
end
362371
end

0 commit comments

Comments
 (0)