Skip to content

Commit cdb1e18

Browse files
authored
More tests and bugfixes for CUSOLVER (#2707)
* More tests and bugfixes for CUSOLVER * Add a test for sytrs
1 parent f7465ff commit cdb1e18

4 files changed

Lines changed: 19 additions & 8 deletions

File tree

lib/cusolver/dense_generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
122122
end
123123

124124
# Xsytrs
125-
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
125+
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
126126
chkuplo(uplo)
127127
n = checksquare(A)
128128
nrhs = size(B, 2)
@@ -149,7 +149,7 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride
149149
B
150150
end
151151

152-
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
152+
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
153153
chkuplo(uplo)
154154
n = checksquare(A)
155155
nrhs = size(B, 2)

lib/cusolver/linalg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
6666
end
6767

6868
function Base.:\(_A::Symmetric{<:Any,<:CuMatOrAdj}, _B::CuOrAdj)
69-
uplo = A.uplo
69+
uplo = _A.uplo
7070
A, B = copy_cublasfloat(_A.data, _B)
7171

7272
# LDLᴴ decomposition with partial pivoting
@@ -371,9 +371,9 @@ for (triangle, uplo, diag) in ((:LowerTriangular, 'L', 'N'),
371371
@eval begin
372372
function LinearAlgebra.inv(A::$triangle{T,<:StridedCuMatrix{T}}) where T <: BlasFloat
373373
n = checksquare(A)
374-
B = copy(A.data)
375-
trtri!(uplo, diag, B)
376-
return B
374+
B = copy(parent(A))
375+
trtri!($uplo, $diag, B)
376+
return $triangle(B)
377377
end
378378
end
379379
end

test/libraries/cusolver/dense.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ k = 1
8989
A = rand(elty,n,n)
9090
A = uplo == 'L' ? tril(A) : triu(A)
9191
A = diag == 'N' ? A : A - Diagonal(A) + I
92-
dA = triangle(CuArray(A))
92+
dA = triangle(view(CuArray(A), 1:2:n, 1:2:n)) # without this view, we are hitting the CUBLAS method!
9393
dA⁻¹ = inv(dA)
94-
dI = dA.data * dA⁻¹
94+
dI = CuArray(dA) * CuArray(dA⁻¹)
9595
@test Array(dI) I
9696
end
9797
end
@@ -265,6 +265,8 @@ k = 1
265265
A += A'
266266
d_A = CuArray(A)
267267
Eig = eigen(LinearAlgebra.Hermitian(A))
268+
d_eig = eigen(d_A)
269+
@test Eig.values collect(d_eig.values)
268270
d_eig = eigen(LinearAlgebra.Hermitian(d_A))
269271
@test Eig.values collect(d_eig.values)
270272
h_V = collect(d_eig.vectors)
@@ -533,6 +535,7 @@ k = 1
533535
B = rand(elty, n)
534536
d_B = CuArray(B)
535537
@test collect(d_M \ d_B) M \ B
538+
@test_throws DimensionMismatch("arguments must have the same number of rows") d_M \ CUDA.ones(elty, n+1)
536539
A = rand(elty, m, n) # A is a matrix and B,C is a vector
537540
d_A = CuArray(A)
538541
M = qr(A)
@@ -782,7 +785,9 @@ end
782785
Bf = cublasfloat.(B)
783786
bf = cublasfloat.(b)
784787
@test Array(d_A \ d_B) (Af \ Bf)
788+
@test Array(Symmetric(d_A) \ d_B) (Af \ Bf)
785789
@test Array(d_A \ d_b) (Af \ bf)
790+
@test Array(Symmetric(d_A) \ d_b) (Af \ bf)
786791
@inferred d_A \ d_B
787792
@inferred d_A \ d_b
788793
end

test/libraries/cusolver/dense_generic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,27 @@ p = 5
9696
@testset "pivoting = $pivoting" for pivoting in (false, true)
9797
A = rand(elty,n,n)
9898
B = rand(elty,n,p)
99+
C = rand(elty,n)
99100
A = A + transpose(A)
100101
d_A = CuMatrix(A)
101102
d_B = CuMatrix(B)
103+
d_C = CuVector(C)
102104
!pivoting && (CUSOLVER.version() < v"11.7.2") && continue
103105
if pivoting
104106
d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting)
105107
d_ipiv = CuVector{Int64}(d_ipiv)
106108
CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B)
109+
CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_C)
107110
else
108111
d_A, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting)
109112
CUSOLVER.sytrs!(uplo, d_A, d_B)
113+
CUSOLVER.sytrs!(uplo, d_A, d_C)
110114
end
111115
A, ipiv, _ = LAPACK.sytrf!(uplo, A)
112116
LAPACK.sytrs!(uplo, A, ipiv, B)
117+
LAPACK.sytrs!(uplo, A, ipiv, C)
113118
@test B collect(d_B)
119+
@test C collect(d_C)
114120
end
115121
end
116122
end

0 commit comments

Comments
 (0)