Skip to content

Commit 734a729

Browse files
authored
Fix rmul for transpose/adjoint (#2871)
1 parent 863eb44 commit 734a729

2 files changed

Lines changed: 56 additions & 12 deletions

File tree

lib/cublas/linalg.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,21 @@ function LinearAlgebra.lmul!(A::Diagonal{Td,<:CuVector{Td}}, B::Adjoint{Tt, <:Cu
441441
return B
442442
end
443443
# eltypes do not match
444-
LinearAlgebra.rmul!(A::CuMatrix, B::Diagonal{T,<:CuVector{T}}) where {T<:CublasFloat} = lmul!(B, transpose(A))
445-
LinearAlgebra.rmul!(A::Transpose{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat} = lmul!(B, A)
446-
LinearAlgebra.rmul!(A::Adjoint{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat} = conj!(lmul!(B, conj!(A)))
444+
function LinearAlgebra.rmul!(A::CuMatrix, B::Diagonal{T,<:CuVector{T}}) where {T<:CublasFloat}
445+
At = transpose(A)
446+
@. At = B.diag * At
447+
return A
448+
end
449+
function LinearAlgebra.rmul!(A::Transpose{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat}
450+
At = parent(A)
451+
@. At = B.diag * At
452+
return transpose(At)
453+
end
454+
function LinearAlgebra.rmul!(A::Adjoint{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat}
455+
At = parent(A)
456+
@. At = adjoint(B.diag) * At
457+
return adjoint(At)
458+
end
447459

448460
# diagm
449461

test/libraries/cublas/extensions.jl

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -594,19 +594,51 @@ k = 13
594594
end # elty
595595

596596
@testset "rmul/lmul with mixed eltypes ($Tr, $Tc)" for (Tr, Tc) in ((Float32, ComplexF32), (Float64, ComplexF64))
597-
x = rand(Tr,m)
598-
d_x = CuArray(x)
599-
XA = rand(Tc,m,n)
597+
x = rand(Tr,m)
598+
d_x = CuArray(x)
599+
XA = rand(Tc,m,n)
600600
d_XA = CuArray(XA)
601-
d_X = Diagonal(d_x)
601+
d_X = Diagonal(d_x)
602602
lmul!(d_X, d_XA)
603603
@test Array(d_XA) Diagonal(x) * XA
604-
605-
y = rand(Tr,n)
606-
d_y = CuArray(y)
607-
AY = rand(Tc,m,n)
604+
605+
x = rand(Tr,m)
606+
d_x = CuArray(x)
607+
XA = rand(Tc,n,m)
608+
d_AX = transpose(CuArray(XA))
609+
d_X = Diagonal(d_x)
610+
lmul!(d_X, d_AX)
611+
@test Array(d_AX) Diagonal(x) * transpose(XA)
612+
613+
x = rand(Tr,m)
614+
d_x = CuArray(x)
615+
XA = rand(Tc,n,m)
616+
d_AX = adjoint(CuArray(XA))
617+
d_X = Diagonal(d_x)
618+
lmul!(d_X, d_AX)
619+
@test Array(d_AX) Diagonal(x) * adjoint(XA)
620+
621+
y = rand(Tr,n)
622+
d_y = CuArray(y)
623+
AY = rand(Tc,m,n)
608624
d_AY = CuArray(AY)
609-
d_Y = Diagonal(d_y)
625+
d_Y = Diagonal(d_y)
610626
rmul!(d_AY, d_Y)
611627
@test Array(d_AY) AY * Diagonal(y)
628+
629+
y = rand(Tr,n)
630+
d_y = CuArray(y)
631+
AY = rand(Tc,n,m)
632+
d_YA = transpose(CuArray(AY))
633+
d_Y = Diagonal(d_y)
634+
d_YA = rmul!(d_YA, d_Y)
635+
@test Array(d_YA) transpose(AY) * Diagonal(y)
636+
637+
y = rand(Tr,n)
638+
d_y = CuArray(y)
639+
AY = rand(Tc,n,m)
640+
d_YA = adjoint(CuArray(AY))
641+
d_Y = Diagonal(d_y)
642+
d_YA = rmul!(d_YA, d_Y)
643+
@test Array(d_YA) adjoint(AY) * Diagonal(y)
612644
end

0 commit comments

Comments
 (0)