@@ -33,6 +33,13 @@ const LAPACK_UPLO_MAP = Dict(
3333 ' F' => MLIR. API. ENZYMEXLA_LAPACK_UPLO_FULL,
3434)
3535
36+ const SVD_ALGORITHM_MAP = Dict (
37+ " DEFAULT" => MLIR. API. ENZYMEXLA_SVD_ALGORITHM_NONE,
38+ " QRIteration" => MLIR. API. ENZYMEXLA_SVD_ALGORITHM_QRITERATION,
39+ " Jacobi" => MLIR. API. ENZYMEXLA_SVD_ALGORITHM_JACOBI,
40+ " DivideAndConquer" => MLIR. API. ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER,
41+ )
42+
3643function _function_macro_error ()
3744 throw (ArgumentError (" `caller_function` is not available in this context" ))
3845end
@@ -3894,26 +3901,16 @@ end
38943901 Vt_size = (batch_sizes... , full ? n : r, n)
38953902 info_size = batch_sizes
38963903
3897- if algorithm == " DEFAULT"
3898- alg = MLIR. API. ENZYMEXLA_SVD_ALGORITHM_NONE
3899- elseif algorithm == " QRIteration"
3900- alg = MLIR. API. ENZYMEXLA_SVD_ALGORITHM_QRITERATION
3901- elseif algorithm == " DivideAndConquer"
3902- alg = MLIR. API. ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER
3903- elseif algorithm == " Jacobi"
3904- alg = MLIR. API. ENZYMEXLA_SVD_ALGORITHM_JACOBI
3905- else
3906- error (" Unsupported SVD algorithm: $algorithm " )
3907- end
3908-
39093904 svd_op = enzymexla. linalg_svd (
39103905 x. mlir_data;
39113906 U= mlir_type (TracedRArray{T,N}, U_size),
39123907 S= mlir_type (TracedRArray{Base. real (T),N - 1 }, S_size),
39133908 Vt= mlir_type (TracedRArray{T,N}, Vt_size),
39143909 info= mlir_type (TracedRArray{iT,N - 2 }, info_size),
39153910 full= full,
3916- algorithm= MLIR. API. enzymexlaSVDAlgorithmAttrGet (MLIR. IR. current_context (), alg),
3911+ algorithm= MLIR. API. enzymexlaSVDAlgorithmAttrGet (
3912+ MLIR. IR. current_context (), SVD_ALGORITHM_MAP[algorithm]
3913+ ),
39173914 location,
39183915 )
39193916
0 commit comments