Skip to content

Commit 75ca0b3

Browse files
committed
fix: dont accidentally raise after fallback lowering
1 parent f3d54a2 commit 75ca0b3

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

src/Compiler.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ function optimization_passes(
704704
lower_comms::Bool=true,
705705
backend::String="gpu",
706706
is_sharded::Bool=false,
707+
raise_shlo_to_blas_lapack::Bool=true,
707708
)
708709
(; max_constant_threshold) = compile_options
709710

@@ -918,7 +919,9 @@ function optimization_passes(
918919

919920
if !is_sharded
920921
# these passes don't have optimized sharding implementations
921-
append!(transform_passes_list, ["dot_general_to_syrk"])
922+
if raise_shlo_to_blas_lapack
923+
append!(transform_passes_list, ["dot_general_to_syrk"])
924+
end
922925
end
923926

924927
if !compile_options.disable_auto_batching_passes
@@ -2026,6 +2029,7 @@ function compile_mlir!(
20262029
lower_comms,
20272030
backend,
20282031
is_sharded,
2032+
raise_shlo_to_blas_lapack=false,
20292033
),
20302034
"post_op_transpose_reshape",
20312035
)
@@ -2168,7 +2172,15 @@ function compile_mlir!(
21682172
run_pass_pipeline!(
21692173
mod,
21702174
join(
2171-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2175+
[
2176+
opt_passes,
2177+
"canonicalize",
2178+
"cse",
2179+
"canonicalize",
2180+
opt_passes2,
2181+
lower_enzymexla_linalg_pass,
2182+
jit,
2183+
],
21722184
",",
21732185
),
21742186
"mid_pad_opts",

0 commit comments

Comments
 (0)