Skip to content

Commit 6f87c1a

Browse files
committed
fix: more
1 parent c6bbc63 commit 6f87c1a

1 file changed

Lines changed: 10 additions & 13 deletions

File tree

src/Ops.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ const LAPACK_UPLO_MAP = Dict(
3333
'F' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_FULL,
3434
)
3535

36+
const SVD_ALGORITHM_MAP = Dict(
37+
"DEFAULT" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_NONE,
38+
"QRIteration" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_QRITERATION,
39+
"Jacobi" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_JACOBI,
40+
"DivideAndConquer" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER,
41+
)
42+
3643
function _function_macro_error()
3744
throw(ArgumentError("`caller_function` is not available in this context"))
3845
end
@@ -3894,26 +3901,16 @@ end
38943901
Vt_size = (batch_sizes..., full ? n : r, n)
38953902
info_size = batch_sizes
38963903

3897-
if algorithm == "DEFAULT"
3898-
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_NONE
3899-
elseif algorithm == "QRIteration"
3900-
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_QRITERATION
3901-
elseif algorithm == "DivideAndConquer"
3902-
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER
3903-
elseif algorithm == "Jacobi"
3904-
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_JACOBI
3905-
else
3906-
error("Unsupported SVD algorithm: $algorithm")
3907-
end
3908-
39093904
svd_op = enzymexla.linalg_svd(
39103905
x.mlir_data;
39113906
U=mlir_type(TracedRArray{T,N}, U_size),
39123907
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
39133908
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
39143909
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
39153910
full=full,
3916-
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.current_context(), alg),
3911+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(
3912+
MLIR.IR.current_context(), SVD_ALGORITHM_MAP[algorithm]
3913+
),
39173914
location,
39183915
)
39193916

0 commit comments

Comments
 (0)