Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
version = "0.2.180"
version = "0.2.181"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.265"
Reactant_jll = "0.0.266"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
32 changes: 29 additions & 3 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,8 @@ function optimization_passes(
recognize_comms::Bool=true,
lower_comms::Bool=true,
backend::String="gpu",
is_sharded::Bool=false,
raise_shlo_to_blas_lapack::Bool=true,
)
(; max_constant_threshold) = compile_options

Expand Down Expand Up @@ -909,8 +911,19 @@ function optimization_passes(
"transpose_symmetric_simplify",
"divide_negated_operands_simplify",
"multiply_negated_operands_simplify",
"transpose_syrk_to_syrk",
"fuse_mul_into_syrk",
"fuse_add_into_syrk",
"factor_scalars_in_dot_general",
]

if !is_sharded
# these passes don't have optimized sharding implementations
if raise_shlo_to_blas_lapack
append!(transform_passes_list, ["dot_general_to_syrk"])
end
end

if !compile_options.disable_auto_batching_passes
append!(
transform_passes_list,
Expand Down Expand Up @@ -1693,10 +1706,10 @@ function compile_mlir!(
end

opt_passes = optimization_passes(
compile_options; sroa=true, recognize_comms, lower_comms, backend
compile_options; sroa=true, recognize_comms, lower_comms, backend, is_sharded
)
opt_passes2 = optimization_passes(
compile_options; sroa=false, recognize_comms, lower_comms, backend
compile_options; sroa=false, recognize_comms, lower_comms, backend, is_sharded
)

raise_passes = if raise isa String
Expand All @@ -1718,6 +1731,7 @@ function compile_mlir!(
recognize_comms,
lower_comms,
backend,
is_sharded,
)
result = result * "," * opt_passes3
end
Expand All @@ -1728,6 +1742,8 @@ function compile_mlir!(

blas_int_width = sizeof(BlasInt) * 8
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
blas_int_width=$blas_int_width},\
lower-enzymexla-blas{backend=$backend \
blas_int_width=$blas_int_width},\
lower-enzymexla-lapack{backend=$backend \
blas_int_width=$blas_int_width}"
Expand Down Expand Up @@ -2012,6 +2028,8 @@ function compile_mlir!(
recognize_comms,
lower_comms,
backend,
is_sharded,
raise_shlo_to_blas_lapack=false,
),
"post_op_transpose_reshape",
)
Expand Down Expand Up @@ -2154,7 +2172,15 @@ function compile_mlir!(
run_pass_pipeline!(
mod,
join(
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
[
opt_passes,
"canonicalize",
"cse",
"canonicalize",
opt_passes2,
lower_enzymexla_linalg_pass,
jit,
],
",",
),
"mid_pad_opts",
Expand Down
15 changes: 15 additions & 0 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ function __init__()
(BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
(BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
(BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
# syrk
(BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_),
(BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_),
(BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_),
(BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_),
# trmm
(BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_),
(BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_),
(BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_),
(BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_),
# symm
(BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_),
(BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_),
(BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_),
(BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_),
]
sym = Libdl.dlsym(libblastrampoline_handle, cname)
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
Expand Down
21 changes: 21 additions & 0 deletions test/integration/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,24 @@ end
@jit LinearAlgebra.normalize!(x_ra)
@test x_ra ≈ x
end

raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y
raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y

@testset "syrk optimizations" begin
@testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
x = Reactant.TestUtils.construct_test_array(elty, 4, 5)
y1 = Reactant.TestUtils.construct_test_array(elty, 4, 4)
y2 = Reactant.TestUtils.construct_test_array(elty, 5, 5)
x_ra = Reactant.to_rarray(x)

@testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2))
y_ra = Reactant.to_rarray(y)

hlo = @code_hlo optimize = :before_jit fn(x_ra, y_ra)
@test occursin("enzymexla.blas.syrk", repr(hlo))

@test @jit(fn(x_ra, y_ra)) ≈ fn(x, y)
end
end
end
Loading