Skip to content

Commit 06df40e

Browse files
committed
test: raising to syrk
1 parent 0fb46c9 commit 06df40e

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

src/stdlibs/LinearAlgebra.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ function __init__()
4646
(BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
4747
(BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
4848
(BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
49+
# syrk
50+
(BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_),
51+
(BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_),
52+
(BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_),
53+
(BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_),
54+
# trmm
55+
(BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_),
56+
(BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_),
57+
(BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_),
58+
(BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_),
59+
# symm
60+
(BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_),
61+
(BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_),
62+
(BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_),
63+
(BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_),
4964
]
5065
sym = Libdl.dlsym(libblastrampoline_handle, cname)
5166
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(

test/integration/linear_algebra.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,24 @@ end
723723
@jit LinearAlgebra.normalize!(x_ra)
724724
@test x_ra x
725725
end
726+
727+
raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y
728+
raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y
729+
730+
@testset "syrk optimizations" begin
731+
@testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
732+
x = Reactant.TestUtils.construct_test_array(elty, 4, 5)
733+
y1 = Reactant.TestUtils.construct_test_array(elty, 4, 4)
734+
y2 = Reactant.TestUtils.construct_test_array(elty, 5, 5)
735+
x_ra = Reactant.to_rarray(x)
736+
737+
@testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2))
738+
y_ra = Reactant.to_rarray(y)
739+
740+
hlo = @code_hlo optimize=:before_jit fn(x_ra, y_ra)
741+
@test occursin("enzymexla.blas.syrk", repr(hlo))
742+
743+
@test @jit(fn(x_ra, y_ra)) fn(x, y)
744+
end
745+
end
746+
end

0 commit comments

Comments
 (0)