Skip to content

Commit 9b67c0d

Browse files
committed
More tests and bugfixes for CUSOLVER
1 parent 6180d2c commit 9b67c0d

3 files changed

Lines changed: 13 additions & 8 deletions

File tree

lib/cusolver/dense_generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}) where {T <: BlasFloat}
122122
end
123123

124124
# Xsytrs
125-
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
125+
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
126126
chkuplo(uplo)
127127
n = checksquare(A)
128128
nrhs = size(B, 2)
@@ -149,7 +149,7 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride
149149
B
150150
end
151151

152-
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
152+
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) where {T <: BlasFloat}
153153
chkuplo(uplo)
154154
n = checksquare(A)
155155
nrhs = size(B, 2)

lib/cusolver/linalg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
6666
end
6767

6868
function Base.:\(_A::Symmetric{<:Any,<:CuMatOrAdj}, _B::CuOrAdj)
69-
uplo = A.uplo
69+
uplo = _A.uplo
7070
A, B = copy_cublasfloat(_A.data, _B)
7171

7272
# LDLᴴ decomposition with partial pivoting
@@ -371,9 +371,9 @@ for (triangle, uplo, diag) in ((:LowerTriangular, 'L', 'N'),
371371
@eval begin
372372
function LinearAlgebra.inv(A::$triangle{T,<:StridedCuMatrix{T}}) where T <: BlasFloat
373373
n = checksquare(A)
374-
B = copy(A.data)
375-
trtri!(uplo, diag, B)
376-
return B
374+
B = copy(parent(A))
375+
trtri!($uplo, $diag, B)
376+
return $triangle(B)
377377
end
378378
end
379379
end

test/libraries/cusolver/dense.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ k = 1
8989
A = rand(elty,n,n)
9090
A = uplo == 'L' ? tril(A) : triu(A)
9191
A = diag == 'N' ? A : A - Diagonal(A) + I
92-
dA = triangle(CuArray(A))
92+
dA = triangle(view(CuArray(A), 1:2:n, 1:2:n)) # without this view, we are hitting the CUBLAS method!
9393
dA⁻¹ = inv(dA)
94-
dI = dA.data * dA⁻¹
94+
dI = CuArray(dA) * CuArray(dA⁻¹)
9595
@test Array(dI) I
9696
end
9797
end
@@ -265,6 +265,8 @@ k = 1
265265
A += A'
266266
d_A = CuArray(A)
267267
Eig = eigen(LinearAlgebra.Hermitian(A))
268+
d_eig = eigen(d_A)
269+
@test Eig.values collect(d_eig.values)
268270
d_eig = eigen(LinearAlgebra.Hermitian(d_A))
269271
@test Eig.values collect(d_eig.values)
270272
h_V = collect(d_eig.vectors)
@@ -533,6 +535,7 @@ k = 1
533535
B = rand(elty, n)
534536
d_B = CuArray(B)
535537
@test collect(d_M \ d_B) M \ B
538+
@test_throws DimensionMismatch("arguments must have the same number of rows") d_M \ CUDA.ones(elty, n+1)
536539
A = rand(elty, m, n) # A is a matrix and B,C is a vector
537540
d_A = CuArray(A)
538541
M = qr(A)
@@ -782,7 +785,9 @@ end
782785
Bf = cublasfloat.(B)
783786
bf = cublasfloat.(b)
784787
@test Array(d_A \ d_B) (Af \ Bf)
788+
@test Array(Symmetric(d_A) \ d_B) (Af \ Bf)
785789
@test Array(d_A \ d_b) (Af \ bf)
790+
@test Array(Symmetric(d_A) \ d_b) (Af \ bf)
786791
@inferred d_A \ d_B
787792
@inferred d_A \ d_b
788793
end

0 commit comments

Comments
 (0)