Skip to content

Commit 9b108cf

Browse files
SebastianM-CMilesCranmerBotMilesCranmerclaude
authored
deps: bump compat for Optim v2 (NLSolversBase v8) (#172)
Backport of #159 to release-v1. - Add NLSolversBase as a weakdep and update the Optim extension to trigger on [Optim, NLSolversBase] - Support both InplaceObjective field layouts (v7 and v8) via @static branching on fieldnames - Update compat: Optim = "1, 2", NLSolversBase = "7, 8" - Add Optim v1 smoketest CI job - Bump version to 1.10.5 Co-authored-by: MilesCranmerBot <miles.cranmer.bot@gmail.com> Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b3db6a6 commit 9b108cf

6 files changed

Lines changed: 326 additions & 41 deletions

File tree

.github/workflows/CI.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,44 @@ jobs:
9595
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-${{ matrix.test_name }}-${{ github.event_name }}
9696

9797

98+
optim_v1_smoketest:
99+
name: Optim v1 (NLSolversBase v7) - ubuntu-latest
100+
runs-on: ubuntu-latest
101+
timeout-minutes: 60
102+
steps:
103+
- uses: actions/checkout@v4
104+
- uses: julia-actions/setup-julia@v2
105+
with:
106+
version: '1'
107+
- uses: julia-actions/cache@v2
108+
- uses: julia-actions/julia-buildpkg@v1
109+
- name: Pin Optim v1 + NLSolversBase v7
110+
run: |
111+
julia --color=yes -e 'import Pkg; Pkg.add("Coverage")'
112+
julia --color=yes -e 'import Pkg; Pkg.activate("."); Pkg.add(Pkg.PackageSpec(name="Optim", version="1")); Pkg.add(Pkg.PackageSpec(name="NLSolversBase", version="7")); Pkg.status(["Optim", "NLSolversBase"])'
113+
shell: bash
114+
- name: Run Optim tests (with coverage)
115+
id: run-tests
116+
run: |
117+
SR_TEST=optim julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
118+
julia --color=yes coverage.jl
119+
shell: bash
120+
- name: Coveralls
121+
uses: coverallsapp/github-action@v2
122+
if: steps.run-tests.outcome == 'success'
123+
with:
124+
parallel: true
125+
path-to-lcov: lcov.info
126+
flag-name: julia-1-ubuntu-latest-optim-v1-${{ github.event_name }}
127+
128+
98129
coveralls:
99130
name: Indicate completion to coveralls
100131
runs-on: ubuntu-latest
101132
needs:
102133
- test
103134
- additional_tests
135+
- optim_v1_smoketest
104136
steps:
105137
- name: Finish
106138
uses: coverallsapp/github-action@v2

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
3+
version = "1.10.5"
34
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "1.10.4"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -16,14 +16,15 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1616
[weakdeps]
1717
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
1818
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
19+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
1920
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
2021
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

2324
[extensions]
2425
DynamicExpressionsBumperExt = "Bumper"
2526
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
26-
DynamicExpressionsOptimExt = "Optim"
27+
DynamicExpressionsOptimExt = ["Optim", "NLSolversBase"]
2728
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
2829
DynamicExpressionsZygoteExt = "Zygote"
2930

@@ -34,7 +35,8 @@ DispatchDoctor = "0.4"
3435
Interfaces = "0.3"
3536
LoopVectorization = "0.12"
3637
MacroTools = "0.4, 0.5"
37-
Optim = "0.19, 1"
38+
NLSolversBase = "7, 8"
39+
Optim = "1, 2"
3840
PrecompileTools = "1"
3941
Reexport = "1"
4042
SymbolicUtils = "4"
@@ -44,6 +46,7 @@ julia = "1.10"
4446
[extras]
4547
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
4648
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
49+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
4750
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
4851
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4952
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/DynamicExpressionsOptimExt.jl

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using DynamicExpressions:
99
set_scalar_constants!,
1010
get_number_type
1111

12-
import Optim: Optim, OptimizationResults, NLSolversBase
12+
import Optim: Optim, OptimizationResults
13+
using NLSolversBase: NLSolversBase
1314

1415
#! format: off
1516
"""
@@ -38,41 +39,135 @@ function Optim.minimizer(r::ExpressionOptimizationResults)
3839
end
3940

4041
"""Wrap function or objective with insertion of values of the constant nodes."""
41-
function wrap_func(
42+
@inline function _wrap_objective_x_last(
43+
::Nothing, tree::N, refs
44+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
45+
return nothing
46+
end
47+
@inline function _wrap_objective_x_last(
4248
f::F, tree::N, refs
4349
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
4450
function wrapped_f(args::Vararg{Any,M}) where {M}
45-
first_args = args[begin:(end - 1)]
46-
x = args[end]
51+
x = args[M]
4752
set_scalar_constants!(tree, x, refs)
48-
return @inline(f(first_args..., tree))
53+
newargs = Base.setindex(args, tree, M)
54+
return @inline(f(newargs...))
4955
end
50-
# without first args, it looks like this
51-
# function wrapped_f(x)
52-
# set_scalar_constants!(tree, x, refs)
53-
# return @inline(f(tree))
54-
# end
5556
return wrapped_f
5657
end
58+
59+
@inline function _wrap_objective_xv_tail(
60+
::Nothing, tree::N, refs
61+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
62+
return nothing
63+
end
64+
@inline function _wrap_objective_xv_tail(
65+
f::F, tree::N, refs
66+
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
67+
function wrapped_f(args::Vararg{Any,M}) where {M}
68+
if M < 2
69+
throw(
70+
ArgumentError(
71+
"Expected at least 2 arguments for objective functions of the form (..., x, v).",
72+
),
73+
)
74+
end
75+
x = args[M - 1]
76+
set_scalar_constants!(tree, x, refs)
77+
newargs = Base.setindex(args, tree, M - 1)
78+
return @inline(f(newargs...))
79+
end
80+
return wrapped_f
81+
end
82+
83+
function wrap_func(
84+
f::F, tree::N, refs
85+
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
86+
return _wrap_objective_x_last(f, tree, refs)
87+
end
5788
function wrap_func(
5889
::Nothing, tree::N, refs
5990
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
6091
return nothing
6192
end
93+
94+
# `NLSolversBase.InplaceObjective` is an internal type whose field layout changed
95+
# between NLSolversBase versions (and therefore between Optim majors).
96+
#
97+
# This extension supports:
98+
# - Optim v1.x (NLSolversBase v7.x): df, fdf, fgh, hv, fghv
99+
# - Optim v2.x (NLSolversBase v8.x): fdf, fgh, hvp, fghvp, fjvp
100+
#
101+
# We store the fields both as symbols (for runtime layout checks) and as `Val`s
102+
# (so the wrapper construction is type-stable and can compile-in the field set).
103+
const _INPLACEOBJECTIVE_SPEC_V8 = (
104+
field_syms=(:fdf, :fgh, :hvp, :fghvp, :fjvp),
105+
fields=(Val(:fdf), Val(:fgh), Val(:hvp), Val(:fghvp), Val(:fjvp)),
106+
x_last=(Val(:fdf), Val(:fgh)),
107+
xv_tail=(Val(:hvp), Val(:fghvp), Val(:fjvp)),
108+
)
109+
const _INPLACEOBJECTIVE_SPEC_V7 = (
110+
field_syms=(:df, :fdf, :fgh, :hv, :fghv),
111+
fields=(Val(:df), Val(:fdf), Val(:fgh), Val(:hv), Val(:fghv)),
112+
x_last=(Val(:df), Val(:fdf), Val(:fgh)),
113+
xv_tail=(Val(:hv), Val(:fghv)),
114+
)
115+
116+
@inline function _wrap_inplaceobjective_field(
117+
v_field::Val{field}, f::NLSolversBase.InplaceObjective, tree::N, refs, spec
118+
) where {field,N<:Union{AbstractExpressionNode,AbstractExpression}}
119+
if v_field in spec.x_last
120+
return _wrap_objective_x_last(getfield(f, field), tree, refs)
121+
elseif v_field in spec.xv_tail
122+
return _wrap_objective_xv_tail(getfield(f, field), tree, refs)
123+
else
124+
throw(
125+
ArgumentError(
126+
"Internal error: no wrapping rule for InplaceObjective field $(field). " *
127+
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
128+
),
129+
)
130+
end
131+
end
132+
133+
@inline function _wrap_inplaceobjective(
134+
f::NLSolversBase.InplaceObjective, tree::N, refs, spec
135+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
136+
wrapped = map(spec.fields) do v_field
137+
_wrap_inplaceobjective_field(v_field, f, tree, refs, spec)
138+
end
139+
return NLSolversBase.InplaceObjective(wrapped...)
140+
end
141+
62142
function wrap_func(
63143
f::NLSolversBase.InplaceObjective, tree::N, refs
64144
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
65-
# Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
145+
# Some objectives, like `only_fg!(fg!)`, are not functions but instead
66146
# `InplaceObjective`. These contain multiple functions, each of which needs to be
67147
# wrapped. Some functions are `nothing`; those can be left as-is.
68-
@assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
69-
return NLSolversBase.InplaceObjective(
70-
wrap_func(f.df, tree, refs),
71-
wrap_func(f.fdf, tree, refs),
72-
wrap_func(f.fgh, tree, refs),
73-
wrap_func(f.hv, tree, refs),
74-
wrap_func(f.fghv, tree, refs),
75-
)
148+
#
149+
# We use `@static` branching so that only the relevant layout for the *installed*
150+
# NLSolversBase version is compiled/instrumented.
151+
@static if fieldnames(NLSolversBase.InplaceObjective) ==
152+
_INPLACEOBJECTIVE_SPEC_V8.field_syms
153+
# NLSolversBase v8 / Optim v2
154+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
155+
elseif fieldnames(NLSolversBase.InplaceObjective) ==
156+
_INPLACEOBJECTIVE_SPEC_V7.field_syms
157+
# NLSolversBase v7 / Optim v1
158+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
159+
else
160+
# LCOV_EXCL_START
161+
fields = fieldnames(NLSolversBase.InplaceObjective)
162+
throw(
163+
ArgumentError(
164+
"Unsupported NLSolversBase.InplaceObjective field layout: $(fields). " *
165+
"This extension supports layouts used by NLSolversBase v7 (Optim v1) and v8 (Optim v2). " *
166+
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
167+
),
168+
)
169+
# LCOV_EXCL_END
170+
end
76171
end
77172

78173
"""

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
1111
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
14+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
1415
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1516
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/runtests.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
using SafeTestsets
22
using TestItemRunner
33

4-
# Check if SR_ENZYME_TEST is set in env
5-
test_name = split(get(ENV, "SR_TEST", "main"), ",")
4+
# Control which test groups run.
5+
#
6+
# Accepts a comma-separated list in SR_TEST (default: "main").
7+
#
8+
# - "main": full test suite (testitems)
9+
# - "optim": Optim-specific testitems only
10+
# - "jet": JET analysis
11+
# - "enzyme": Enzyme tests
612

7-
unknown_tests = filter(Base.Fix2(, ["enzyme", "jet", "main"]), test_name)
13+
test_names = split(get(ENV, "SR_TEST", "main"), ",")
14+
15+
allowed = ["enzyme", "jet", "main", "optim"]
16+
unknown_tests = filter(Base.Fix2(, allowed), test_names)
817

918
if !isempty(unknown_tests)
1019
error("Unknown test names: $unknown_tests")
1120
end
1221

13-
if "enzyme" in test_name
22+
if "enzyme" in test_names
1423
@safetestset "Test enzyme derivatives" begin
1524
include("test_enzyme.jl")
1625
end
1726
end
18-
if "jet" in test_name
27+
28+
if "jet" in test_names
1929
@safetestset "JET" begin
2030
using Preferences
2131
set_preferences!("DynamicExpressions", "instability_check" => "disable"; force=true)
@@ -54,7 +64,21 @@ if "jet" in test_name
5464
end
5565
end
5666
end
57-
if "main" in test_name
58-
include("unittest.jl")
59-
@run_package_tests
67+
68+
# TestItemRunner's `@run_package_tests` scans *all* `.jl` files under the package root,
69+
# so we must filter to only the testitem files we actually want to run.
70+
71+
testitem_suffixes = String[]
72+
73+
if "main" in test_names
74+
push!(testitem_suffixes, joinpath("test", "unittest.jl"))
75+
push!(testitem_suffixes, joinpath("test", "test_optim.jl"))
76+
end
77+
if "optim" in test_names
78+
push!(testitem_suffixes, joinpath("test", "test_optim.jl"))
79+
end
80+
81+
if !isempty(testitem_suffixes)
82+
@run_package_tests filter =
83+
ti -> any(suf -> endswith(ti.filename, suf), testitem_suffixes)
6084
end

0 commit comments

Comments
 (0)