Skip to content

Commit 506a724

Browse files
committed
fix: temp patch for syrk lowering
1 parent c6e3bbd commit 506a724

3 files changed

Lines changed: 9 additions & 11 deletions

File tree

src/Ops.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4215,20 +4215,22 @@ end
42154215
alpha::Union{TracedRNumber{T},T},
42164216
beta::Union{TracedRNumber{T},T};
42174217
uplo::Char,
4218+
output_uplo::Char,
42184219
transpose_a::Char,
42194220
location=mlir_stacktrace("syrk", @__FILE__, @__LINE__),
42204221
) where {T,N}
42214222
ctx = MLIR.IR.current_context()
4222-
uplo_attr = MLIR.API.enzymexlaLapackUploAttrGet(ctx, LAPACK_UPLO_MAP[uplo])
42234223

42244224
res = MLIR.IR.result(
42254225
enzymexla.blas_syrk(
42264226
A.mlir_data,
42274227
C.mlir_data,
42284228
constant(alpha; location).mlir_data,
42294229
constant(beta; location).mlir_data;
4230-
uplo=uplo_attr,
4231-
output_uplo=uplo_attr,
4230+
uplo=MLIR.API.enzymexlaLapackUploAttrGet(ctx, LAPACK_UPLO_MAP[uplo]),
4231+
output_uplo=MLIR.API.enzymexlaLapackUploAttrGet(
4232+
ctx, LAPACK_UPLO_MAP[output_uplo]
4233+
),
42324234
transpose=MLIR.API.enzymexlaLapackTransposeAttrGet(
42334235
ctx, LAPACK_TRANSPOSE_MAP[transpose_a]
42344236
),

src/stdlibs/BLAS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ function BLAS.syrk!(
523523
alpha,
524524
beta;
525525
uplo=uplo,
526+
output_uplo='F', # TODO: this is a temporary patch. Why is the default flipped?
526527
transpose_a=trans,
527528
)
528529
copyto!(C, res)

test/integration/blas.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
using LinearAlgebra, Reactant, Test
22
using LinearAlgebra: BLAS
33

4-
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
5-
64
@testset "Level 1" begin
75
@testset "asum" begin
86
x = Reactant.TestUtils.construct_test_array(Float32, 32)
@@ -284,8 +282,7 @@ end
284282
@jit BLAS.syrk!('U', 'N', 2.0f0, A_ra, 3.0f0, C_ra)
285283
C_target = copy(C)
286284
BLAS.syrk!('U', 'N', 2.0f0, A, 3.0f0, C_target)
287-
@test UpperTriangular(C_ra) ≈ UpperTriangular(C_target) atol = 1e-3 rtol = 1e-3 broken =
288-
!RunningOnTPU
285+
@test UpperTriangular(C_ra) ≈ UpperTriangular(C_target) atol = 1e-3 rtol = 1e-3
289286

290287
# test 'L' and 'T'
291288
A2 = Reactant.TestUtils.construct_test_array(Float32, 16, 16)
@@ -296,8 +293,7 @@ end
296293
@jit BLAS.syrk!('L', 'T', 2.0f0, A2_ra, 3.0f0, C2_ra)
297294
C2_target = copy(C2)
298295
BLAS.syrk!('L', 'T', 2.0f0, A2, 3.0f0, C2_target)
299-
@test LowerTriangular(C2_ra) ≈ LowerTriangular(C2_target) atol = 1e-3 rtol = 1e-3 broken =
300-
!RunningOnTPU
296+
@test LowerTriangular(C2_ra) ≈ LowerTriangular(C2_target) atol = 1e-3 rtol = 1e-3
301297
end
302298

303299
if isdefined(BLAS, :gemmt!)
@@ -415,8 +411,7 @@ end
415411
BLAS.trsm('L', 'U', 'N', 'N', 2.0f0, Ainv, B) atol = 1e-3 rtol = 1e-3
416412

417413
@test UpperTriangular(@jit(BLAS.syrk('U', 'N', 2.0f0, A_ra))) ≈
418-
UpperTriangular(BLAS.syrk('U', 'N', 2.0f0, A)) atol = 1e-3 rtol = 1e-3 broken =
419-
!RunningOnTPU
414+
UpperTriangular(BLAS.syrk('U', 'N', 2.0f0, A)) atol = 1e-3 rtol = 1e-3
420415
@test UpperTriangular(@jit(BLAS.syr2k('U', 'N', 2.0f0, A_ra, B_ra))) ≈
421416
UpperTriangular(BLAS.syr2k('U', 'N', 2.0f0, A, B)) atol = 1e-3 rtol = 1e-3
422417

0 commit comments

Comments
 (0)