Skip to content

Commit 3c761da

Browse files
committed
test: update
1 parent c4db84b commit 3c761da

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

test/nn/lux.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,18 @@ end
9797
res, st_new = @jit model(x, ps, st)
9898
@test st_new.rng isa Reactant.ReactantRNG
9999
end
100+
100101
@testset "Lux Parameter JVP with Bias" begin
101-
noisy = Reactant.TestUtils.construct_test_array(Float32, 5, 2)
102+
input = Reactant.TestUtils.construct_test_array(Float32, 5, 2)
102103
model = Dense(5 => 4; use_bias=true)
103104
ps, st = Lux.setup(Xoshiro(0), model)
104-
cnoisy = Reactant.to_rarray(noisy)
105+
cinput = Reactant.to_rarray(input)
105106
cps = Reactant.to_rarray(ps)
106107
cst = Reactant.to_rarray(st)
107108

108-
sm_input = Lux.StatefulLuxLayer{true}(model, cps, cst)
109-
f(p) = first(model(cnoisy, p, cst))
109+
sm = Lux.StatefulLuxLayer{true}(model, cps, cst)
110+
sm_input = Base.Fix1(sm, cinput)
110111

111-
# test jvp
112-
jvp = @jit jacobian_vector_product(f, AutoEnzyme(), cps, cps)
113-
@test jvp isa NamedTuple
112+
jvp = @jit jacobian_vector_product(sm_input, AutoEnzyme(), cps, cps)
113+
@test jvp isa Reactant.ConcreteRArray
114114
end

0 commit comments

Comments
 (0)