Skip to content

Commit cf4b06d

Browse files
authored
Support in-place rmul/lmul for Diagonals (#2856)
1 parent 7b1a989 commit cf4b06d

2 files changed

Lines changed: 22 additions & 2 deletions

File tree

lib/cublas/linalg.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# interfacing with LinearAlgebra standard library
22

3-
using LinearAlgebra: MulAddMul, AdjOrTrans, wrap, UpperOrLowerTriangular
3+
using LinearAlgebra: MulAddMul, AdjOrTrans, wrap, UpperOrLowerTriangular, rmul!, lmul!
44

55
#
66
# BLAS 1
@@ -413,6 +413,14 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoin
413413
return C
414414
end
415415

416+
function LinearAlgebra.lmul!(A::Diagonal{T,<:CuVector{T}}, B::CuMatrix{T}) where {T<:CublasFloat}
417+
return dgmm!('L', B, A.diag, B)
418+
end
419+
420+
function LinearAlgebra.rmul!(A::CuMatrix{T}, B::Diagonal{T,<:CuVector{T}}) where {T<:CublasFloat}
421+
return dgmm!('R', A, B.diag, A)
422+
end
423+
416424
# diagm
417425

418426
LinearAlgebra.diagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm(nothing, kv...)

test/libraries/cublas/extensions.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,18 +531,30 @@ k = 13
531531
h_C = Array(d_C)
532532
@test C h_C
533533
end
534-
@testset "diagonal -- mul!" begin
534+
@testset "diagonal -- mul!, rmul!, lmul!" begin
535535
XA = rand(elty,m,n)
536536
d_XA = CuArray(XA)
537537
d_X = Diagonal(d_x)
538538
mul!(d_XA, d_X, d_A)
539539
Array(d_XA) Diagonal(x) * A
540+
541+
XA = rand(elty,m,n)
542+
d_XA = CuArray(XA)
543+
d_X = Diagonal(d_x)
544+
lmul!(d_X, d_XA)
545+
Array(d_XA) Diagonal(x) * XA
540546

541547
AY = rand(elty,m,n)
542548
d_AY = CuArray(AY)
543549
d_Y = Diagonal(d_y)
544550
mul!(d_AY, d_A, d_Y)
545551
Array(d_AY) A * Diagonal(y)
552+
553+
AY = rand(elty,m,n)
554+
d_AY = CuArray(AY)
555+
d_Y = Diagonal(d_y)
556+
rmul!(d_AY, d_Y)
557+
Array(d_AY) AY * Diagonal(y)
546558

547559
YA = rand(elty,n,m)
548560
d_YA = CuArray(YA)

0 commit comments

Comments
 (0)