|
97 | 97 | res, st_new = @jit model(x, ps, st) |
98 | 98 | @test st_new.rng isa Reactant.ReactantRNG |
99 | 99 | end |
| 100 | + |
100 | 101 | @testset "Lux Parameter JVP with Bias" begin |
101 | | - noisy = Reactant.TestUtils.construct_test_array(Float32, 5, 2) |
| 102 | + input = Reactant.TestUtils.construct_test_array(Float32, 5, 2) |
102 | 103 | model = Dense(5 => 4; use_bias=true) |
103 | 104 | ps, st = Lux.setup(Xoshiro(0), model) |
104 | | - cnoisy = Reactant.to_rarray(noisy) |
| 105 | + cinput = Reactant.to_rarray(input) |
105 | 106 | cps = Reactant.to_rarray(ps) |
106 | 107 | cst = Reactant.to_rarray(st) |
107 | 108 |
|
108 | | - sm_input = Lux.StatefulLuxLayer{true}(model, cps, cst) |
109 | | - f(p) = first(model(cnoisy, p, cst)) |
| 109 | + sm = Lux.StatefulLuxLayer{true}(model, cps, cst) |
| 110 | + sm_input = Base.Fix1(sm, cinput) |
110 | 111 |
|
111 | | - # test jvp |
112 | | - jvp = @jit jacobian_vector_product(f, AutoEnzyme(), cps, cps) |
113 | | - @test jvp isa NamedTuple |
| 112 | + jvp = @jit jacobian_vector_product(sm_input, AutoEnzyme(), cps, cps) |
| 113 | + @test jvp isa Reactant.ConcreteRArray |
114 | 114 | end |
0 commit comments