Skip to content

Commit b764254

Browse files
authored
Add point capability to _produce_type to improve GPU compatibility (#577)
* Add point capability to `_produce_type` to improve GPU compatibility * fixes * fix two issues * slightly nicer fix * improve coverage * we actually need a copy there * set date in changelog
1 parent 9707fba commit b764254

18 files changed

Lines changed: 180 additions & 51 deletions

Changelog.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@ The file was started with Version `0.4`.
66
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
77
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
88

9-
## [0.5.33] unreleased
9+
## [0.5.33] February 18, 2026
1010

1111
### Added
1212

1313
* A clarification on the use of AI in the [CONTRIBUTING.md](https://manoptjl.org/stable/contributing/) (#573)
14+
* `_produce_type` now accepts the point `p` as an optional third argument, which can be used to produce objects with specific point type for internal buffers. The addition has been utilized in `DirectionUpdateRule`s and `Stepsize`s to improve GPU and custom floating point type compatibility. (#577)
1415
* Added another package and paper using `Manopt.jl` to the about page (#576).
1516

17+
### Fixed
18+
19+
* `DistanceOverGradientsStepsize` now requires explicitly passing a point as the second argument because it logically depends on receiving the initial point. (#577)
20+
1621
## [0.5.32] January 15, 2026
1722

1823
### Fixed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Manopt"
22
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
3-
version = "0.5.32"
3+
version = "0.5.33"
44
authors = ["Ronny Bergmann <manopt@ronnybergmann.net>"]
55

66
[workspace]
@@ -54,7 +54,7 @@ LineSearches = "7.2.0"
5454
LinearAlgebra = "1.10"
5555
ManifoldDiff = "0.3.8, 0.4"
5656
Manifolds = "0.11.2"
57-
ManifoldsBase = "2.2.0"
57+
ManifoldsBase = "2.3.1"
5858
Markdown = "1.10"
5959
Plots = "1.30"
6060
Preferences = "1.4"

src/plans/first_order_plan.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ $(_fields(:X; name = "X_old"))
643643
644644
645645
MomentumGradientRule(M::AbstractManifold; kwargs...)
646+
MomentumGradientRule(M::AbstractManifold, p; kwargs...)
646647
647648
Initialize a momentum gradient rule to `s`, where `p` and `X` are memory for interim values.
648649
@@ -665,6 +666,9 @@ mutable struct MomentumGradientRule{
665666
vector_transport_method::VTM
666667
X_old::T
667668
end
669+
function MomentumGradientRule(M::AbstractManifold, p; kwargs...)
670+
return MomentumGradientRule(M; p = copy(M, p), kwargs...)
671+
end
668672
function MomentumGradientRule(
669673
M::AbstractManifold;
670674
p::P = rand(M),
@@ -715,7 +719,7 @@ $(_kwargs(:vector_transport_method))
715719
$(_note(:ManifoldDefaultFactory, "MomentumGradientRule"))
716720
"""
717721
function MomentumGradient(args...; kwargs...)
718-
return ManifoldDefaultsFactory(Manopt.MomentumGradientRule, args...; kwargs...)
722+
return ManifoldDefaultsFactory(Manopt.MomentumGradientRule, args...; requires_point = true, kwargs...)
719723
end
720724

721725
"""
@@ -744,6 +748,7 @@ $(_kwargs(:vector_transport_method))
744748
last_iterate = deepcopy(x0),
745749
vector_transport_method = default_vector_transport_method(M, typeof(p))
746750
)
751+
AverageGradientRule(M::AbstractManifold, p; kwargs...)
747752
748753
Add average to a gradient problem, where
749754
@@ -761,6 +766,9 @@ mutable struct AverageGradientRule{
761766
direction::D
762767
vector_transport_method::VTM
763768
end
769+
function AverageGradientRule(M::AbstractManifold, p; kwargs...)
770+
return AverageGradientRule(M; p = copy(M, p), kwargs...)
771+
end
764772
function AverageGradientRule(
765773
M::AbstractManifold;
766774
p::P = rand(M),
@@ -813,7 +821,7 @@ $(_kwargs([:X, :vector_transport_method]))
813821
$(_note(:ManifoldDefaultFactory, "AverageGradientRule"))
814822
"""
815823
function AverageGradient(args...; kwargs...)
816-
return ManifoldDefaultsFactory(Manopt.AverageGradientRule, args...; kwargs...)
824+
return ManifoldDefaultsFactory(Manopt.AverageGradientRule, args...; requires_point = true, kwargs...)
817825
end
818826

819827
@doc """
@@ -832,6 +840,7 @@ $(_kwargs(:inverse_retraction_method))
832840
# Constructor
833841
834842
NesterovRule(M::AbstractManifold; kwargs...)
843+
NesterovRule(M::AbstractManifold, p; kwargs...)
835844
836845
## Keyword arguments
837846
@@ -852,6 +861,9 @@ mutable struct NesterovRule{P, R <: Real} <: DirectionUpdateRule
852861
shrinkage::Function
853862
inverse_retraction_method::AbstractInverseRetractionMethod
854863
end
864+
function NesterovRule(M::AbstractManifold, p; kwargs...)
865+
return NesterovRule(M; p = copy(M, p), kwargs...)
866+
end
855867
function NesterovRule(
856868
M::AbstractManifold;
857869
p::P = rand(M),
@@ -926,7 +938,7 @@ $(_kwargs(:inverse_retraction_method))
926938
$(_note(:ManifoldDefaultFactory, "NesterovRule"))
927939
"""
928940
function Nesterov(args...; kwargs...)
929-
return ManifoldDefaultsFactory(Manopt.NesterovRule, args...; kwargs...)
941+
return ManifoldDefaultsFactory(Manopt.NesterovRule, args...; requires_point = true, kwargs...)
930942
end
931943

932944
"""

src/plans/manifold_default_factory.jl

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,31 +56,63 @@ struct ManifoldDefaultsFactory{T, TM <: Union{<:AbstractManifold, Nothing}, A, K
5656
args::A
5757
kwargs::K
5858
constructor_requires_manifold::Bool
59+
constructor_requires_point::Bool
5960
end
6061
function ManifoldDefaultsFactory(
61-
T::Type, M::TM, args...; requires_manifold = true, kwargs...
62+
T::Type, M::TM, args...; requires_manifold = true, requires_point = false, kwargs...
6263
) where {TM <: AbstractManifold}
6364
return ManifoldDefaultsFactory{T, TM, typeof(args), typeof(kwargs)}(
64-
M, args, kwargs, requires_manifold
65+
M, args, kwargs, requires_manifold, requires_point
6566
)
6667
end
67-
function ManifoldDefaultsFactory(T::Type, args...; requires_manifold = true, kwargs...)
68+
function ManifoldDefaultsFactory(T::Type, args...; requires_manifold = true, requires_point = false, kwargs...)
6869
return ManifoldDefaultsFactory{T, Nothing, typeof(args), typeof(kwargs)}(
69-
nothing, args, kwargs, requires_manifold
70+
nothing, args, kwargs, requires_manifold, requires_point
7071
)
7172
end
73+
function (mdf::ManifoldDefaultsFactory{T})(M::AbstractManifold, p) where {T}
74+
return if mdf.constructor_requires_manifold
75+
if mdf.constructor_requires_point
76+
return T(M, p, mdf.args...; mdf.kwargs...)
77+
else
78+
return T(M, mdf.args...; mdf.kwargs...)
79+
end
80+
else
81+
if mdf.constructor_requires_point
82+
return T(p, mdf.args...; mdf.kwargs...)
83+
else
84+
return T(mdf.args...; mdf.kwargs...)
85+
end
86+
end
87+
end
7288
function (mdf::ManifoldDefaultsFactory{T})(M::AbstractManifold) where {T}
73-
if mdf.constructor_requires_manifold
74-
return T(M, mdf.args...; mdf.kwargs...)
89+
return if mdf.constructor_requires_manifold
90+
if mdf.constructor_requires_point
91+
return T(M, rand(M), mdf.args...; mdf.kwargs...)
92+
else
93+
return T(M, mdf.args...; mdf.kwargs...)
94+
end
7595
else
76-
return T(mdf.args...; mdf.kwargs...)
96+
if mdf.constructor_requires_point
97+
return T(rand(mdf.M), mdf.args...; mdf.kwargs...)
98+
else
99+
return T(mdf.args...; mdf.kwargs...)
100+
end
77101
end
78102
end
79103
function (mdf::ManifoldDefaultsFactory{T, <:AbstractManifold})() where {T}
80-
if mdf.constructor_requires_manifold
81-
return T(mdf.M, mdf.args...; mdf.kwargs...)
104+
return if mdf.constructor_requires_manifold
105+
if mdf.constructor_requires_point
106+
return T(mdf.M, rand(mdf.M), mdf.args...; mdf.kwargs...)
107+
else
108+
return T(mdf.M, mdf.args...; mdf.kwargs...)
109+
end
82110
else
83-
return T(mdf.args...; mdf.kwargs...)
111+
if mdf.constructor_requires_point
112+
return T(rand(mdf.M), mdf.args...; mdf.kwargs...)
113+
else
114+
return T(mdf.args...; mdf.kwargs...)
115+
end
84116
end
85117
end
86118
function (mdf::ManifoldDefaultsFactory{T, Nothing})() where {T}
@@ -90,13 +122,20 @@ end
90122
"""
91123
_produce_type(t::T, M::AbstractManifold)
92124
_produce_type(t::ManifoldDefaultsFactory{T}, M::AbstractManifold)
125+
_produce_type(t::ManifoldDefaultsFactory{T}, M::AbstractManifold, p)
93126
94127
Use the [`ManifoldDefaultsFactory`](@ref)`{T}` to produce an instance of type `T`.
95128
This acts transparent in the way that if you provide an instance `t::T` already, this will
96129
just be returned.
130+
131+
If a point `p` on manifold `M` is provided, it is passed to the constructor `t` as a
132+
template for allocating points. It is no supposed to be modified by the constructor or
133+
stored in the produced object.
97134
"""
98135
_produce_type(t, M::AbstractManifold) = t
136+
_produce_type(t, M::AbstractManifold, p) = t
99137
_produce_type(t::ManifoldDefaultsFactory, M::AbstractManifold) = t(M)
138+
_produce_type(t::ManifoldDefaultsFactory, M::AbstractManifold, p) = t(M, p)
100139

101140
function show(io::IO, mdf::ManifoldDefaultsFactory{T, M}) where {T, M}
102141
rm = mdf.constructor_requires_manifold

src/plans/stepsize/stepsize.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ mutable struct ArmijoLinesearchStepsize{TRM <: AbstractRetractionMethod, P, I, F
104104
)
105105
end
106106
end
107+
function ArmijoLinesearchStepsize(M::AbstractManifold, p; kwargs...)
108+
return ArmijoLinesearchStepsize(M; candidate_point = allocate(p), kwargs...)
109+
end
110+
107111
function (a::ArmijoLinesearchStepsize)(
108112
mp::AbstractManoptProblem,
109113
s::AbstractManoptSolverState,
@@ -222,7 +226,7 @@ For the stop safe guards you can pass `:Messages` to a `debug=` to see `@info` m
222226
$(_note(:ManifoldDefaultFactory, "ArmijoLinesearchStepsize"))
223227
"""
224228
function ArmijoLinesearch(args...; kwargs...)
225-
return ManifoldDefaultsFactory(Manopt.ArmijoLinesearchStepsize, args...; kwargs...)
229+
return ManifoldDefaultsFactory(Manopt.ArmijoLinesearchStepsize, args...; requires_point = true, kwargs...)
226230
end
227231

228232
@doc """
@@ -293,6 +297,9 @@ function AdaptiveWNGradientStepsize(
293297
0,
294298
)
295299
end
300+
function AdaptiveWNGradientStepsize(M::AbstractManifold, p; kwargs...)
301+
return AdaptiveWNGradientStepsize(M; p = p, kwargs...)
302+
end
296303
function (awng::AdaptiveWNGradientStepsize)(
297304
mp::AbstractManoptProblem,
298305
s::AbstractGradientSolverState,
@@ -402,7 +409,7 @@ $(_kwargs(:p)) only used to define the `gradient_bound`
402409
$(_kwargs(:X)) only used to define the `gradient_bound`
403410
"""
404411
function AdaptiveWNGradient(args...; kwargs...)
405-
return ManifoldDefaultsFactory(Manopt.AdaptiveWNGradientStepsize, args...; kwargs...)
412+
return ManifoldDefaultsFactory(Manopt.AdaptiveWNGradientStepsize, args...; requires_point = true, kwargs...)
406413
end
407414

408415
"""
@@ -508,6 +515,7 @@ $(_fields(:vector_transport_method))
508515
# Constructor
509516
510517
CubicBracketingLinesearchStepsize(M::AbstractManifold; kwargs...)
518+
CubicBracketingLinesearchStepsize(M::AbstractManifold, p; kwargs...)
511519
512520
## Keyword arguments
513521
@@ -559,6 +567,13 @@ mutable struct CubicBracketingLinesearchStepsize{
559567
return new{R, I, TRM, VTM, P, T}(candidate_direction, candidate_point, initial_stepsize, initial_stepsize, retraction_method, stepsize_increase, max_iterations, sufficient_curvature, min_bracket_width, hybrid, vector_transport_method, max_stepsize)
560568
end
561569
end
570+
function CubicBracketingLinesearchStepsize(M::AbstractManifold, p; kwargs...)
571+
candidate_point = allocate(p)
572+
candidate_direction = zero_vector(M, candidate_point)
573+
return CubicBracketingLinesearchStepsize(
574+
M; candidate_point = candidate_point, candidate_direction = candidate_direction, kwargs...
575+
)
576+
end
562577

563578
"""
564579
UnivariateTriple{R <: Real}
@@ -848,7 +863,7 @@ $(_kwargs(:vector_transport_method))
848863
$(_note(:ManifoldDefaultFactory, "CubicBracketingLinesearch"))
849864
"""
850865
function CubicBracketingLinesearch(args...; kwargs...)
851-
return ManifoldDefaultsFactory(CubicBracketingLinesearchStepsize, args...; kwargs...)
866+
return ManifoldDefaultsFactory(CubicBracketingLinesearchStepsize, args...; requires_point = true, kwargs...)
852867
end
853868

854869

@@ -992,8 +1007,8 @@ mutable struct DistanceOverGradientsStepsize{R <: Real, P} <: Stepsize
9921007
end
9931008

9941009
function DistanceOverGradientsStepsize(
995-
M::AbstractManifold;
996-
p = rand(M),
1010+
M::AbstractManifold,
1011+
p;
9971012
initial_distance::R1 = 1.0e-3,
9981013
use_curvature::Bool = false,
9991014
sectional_curvature_bound::R2 = 0.0,
@@ -1164,7 +1179,7 @@ $(doc_DoG_main)
11641179
$(_note(:ManifoldDefaultFactory, "DistanceOverGradientsStepsize"))
11651180
"""
11661181
function DistanceOverGradients(args...; kwargs...)
1167-
return ManifoldDefaultsFactory(Manopt.DistanceOverGradientsStepsize, args...; kwargs...)
1182+
return ManifoldDefaultsFactory(Manopt.DistanceOverGradientsStepsize, args...; requires_point = true, kwargs...)
11681183
end
11691184

11701185
@doc """
@@ -1194,6 +1209,7 @@ $(_kwargs(:vector_transport_method))
11941209
# Constructor
11951210
11961211
NonmonotoneLinesearchStepsize(M::AbstractManifold; kwargs...)
1212+
NonmonotoneLinesearchStepsize(M::AbstractManifold, p; kwargs...)
11971213
11981214
## Keyword arguments
11991215
@@ -1315,6 +1331,9 @@ mutable struct NonmonotoneLinesearchStepsize{
13151331
)
13161332
end
13171333
end
1334+
function NonmonotoneLinesearchStepsize(M::AbstractManifold, p; kwargs...)
1335+
return NonmonotoneLinesearchStepsize(M; p = allocate(p), kwargs...)
1336+
end
13181337
function (a::NonmonotoneLinesearchStepsize)(
13191338
mp::AbstractManoptProblem,
13201339
s::AbstractManoptSolverState,
@@ -1529,7 +1548,7 @@ $(_kwargs(:retraction_method))
15291548
* `stop_decreasing_at_step=1000`: last step size to decrease the stepsize (phase 2),
15301549
"""
15311550
function NonmonotoneLinesearch(args...; kwargs...)
1532-
return ManifoldDefaultsFactory(NonmonotoneLinesearchStepsize, args...; kwargs...)
1551+
return ManifoldDefaultsFactory(NonmonotoneLinesearchStepsize, args...; requires_point = true, kwargs...)
15331552
end
15341553

15351554
@doc """
@@ -1640,6 +1659,7 @@ $(_fields(:vector_transport_method))
16401659
# Constructor
16411660
16421661
WolfePowellLinesearchStepsize(M::AbstractManifold; kwargs...)
1662+
WolfePowellLinesearchStepsize(M::AbstractManifold, p; kwargs...)
16431663
16441664
## Keyword arguments
16451665
@@ -1705,6 +1725,13 @@ mutable struct WolfePowellLinesearchStepsize{
17051725
)
17061726
end
17071727
end
1728+
function WolfePowellLinesearchStepsize(M::AbstractManifold, p; kwargs...)
1729+
candidate_point = allocate(p)
1730+
candidate_direction = zero_vector(M, candidate_point)
1731+
return WolfePowellLinesearchStepsize(
1732+
M; p = candidate_point, X = candidate_direction, kwargs...
1733+
)
1734+
end
17081735
function (a::WolfePowellLinesearchStepsize)(
17091736
mp::AbstractManoptProblem,
17101737
ams::AbstractManoptSolverState,
@@ -1847,7 +1874,7 @@ $(_kwargs(:retraction_method))
18471874
$(_kwargs(:vector_transport_method))
18481875
"""
18491876
function WolfePowellLinesearch(args...; kwargs...)
1850-
return ManifoldDefaultsFactory(WolfePowellLinesearchStepsize, args...; kwargs...)
1877+
return ManifoldDefaultsFactory(WolfePowellLinesearchStepsize, args...; requires_point = true, kwargs...)
18511878
end
18521879

18531880
@doc """

src/solvers/FrankWolfe.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,10 @@ function Frank_Wolfe_method!(
284284
M;
285285
p = copy(M, p),
286286
stopping_criterion = sub_stopping_criterion,
287-
stepsize = default_stepsize(
288-
M, GradientDescentState; retraction_method = retraction_method
287+
stepsize = _produce_type(
288+
default_stepsize(
289+
M, GradientDescentState; retraction_method = retraction_method
290+
), M, p
289291
),
290292
sub_kwargs...,
291293
);
@@ -309,7 +311,7 @@ function Frank_Wolfe_method!(
309311
p = p,
310312
X = X,
311313
retraction_method = retraction_method,
312-
stepsize = _produce_type(stepsize, M),
314+
stepsize = _produce_type(stepsize, M, p),
313315
stopping_criterion = stopping_criterion,
314316
)
315317
dfws = decorate_state!(fws; kwargs...)

src/solvers/conjugate_gradient_descent.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ function conjugate_gradient_descent!(
160160
M;
161161
p = p,
162162
stopping_criterion = stopping_criterion,
163-
stepsize = _produce_type(stepsize, M),
164-
coefficient = _produce_type(coefficient, M),
163+
stepsize = _produce_type(stepsize, M, p),
164+
coefficient = _produce_type(coefficient, M, p),
165165
restart_condition = restart_condition,
166166
retraction_method = retraction_method,
167167
vector_transport_method = vector_transport_method,

0 commit comments

Comments
 (0)