Skip to content

Commit 9745190

Browse files
committed
More tests and bugfixing for CUSPARSE
1 parent 1bcf495 commit 9745190

4 files changed

Lines changed: 27 additions & 6 deletions

File tree

lib/cusparse/conversions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ end
587587

588588
function CuSparseMatrixCSR{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv}
589589
m,n = size(coo)
590+
nnz(coo) == 0 && return CuSparseMatrixCSR{Tv}(CUDA.ones(Cint, m+1), coo.colInd, nonzeros(coo), size(coo))
590591
coo = sort_coo(coo, 'R')
591592
csrRowPtr = CuVector{Cint}(undef, m+1)
592593
cusparseXcoo2csr(handle(), coo.rowInd, nnz(coo), m, csrRowPtr, index)
@@ -595,6 +596,7 @@ end
595596

596597
function CuSparseMatrixCOO{Tv}(csr::CuSparseMatrixCSR{Tv}; index::SparseChar='O') where {Tv}
597598
m,n = size(csr)
599+
nnz(csr) == 0 && return CuSparseMatrixCOO{Tv}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csr), size(csr))
598600
cooRowInd = CuVector{Cint}(undef, nnz(csr))
599601
cusparseXcsr2coo(handle(), csr.rowPtr, nnz(csr), m, cooRowInd, index)
600602
CuSparseMatrixCOO{Tv}(cooRowInd, csr.colVal, nonzeros(csr), size(csr))
@@ -604,6 +606,7 @@ end
604606

605607
function CuSparseMatrixCSC{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv}
606608
m,n = size(coo)
609+
nnz(coo) == 0 && return CuSparseMatrixCSC{Tv}(CUDA.ones(Cint, n+1), coo.rowInd, nonzeros(coo), size(coo))
607610
coo = sort_coo(coo, 'C')
608611
cscColPtr = CuVector{Cint}(undef, n+1)
609612
cusparseXcoo2csr(handle(), coo.colInd, nnz(coo), n, cscColPtr, index)
@@ -612,6 +615,7 @@ end
612615

613616
function CuSparseMatrixCOO{Tv}(csc::CuSparseMatrixCSC{Tv}; index::SparseChar='O') where {Tv}
614617
m,n = size(csc)
618+
nnz(csc) == 0 && return CuSparseMatrixCOO{Tv}(CUDA.zeros(Cint, 0), CUDA.zeros(Cint, 0), nonzeros(csc), size(csc))
615619
cooColInd = CuVector{Cint}(undef, nnz(csc))
616620
cusparseXcsr2coo(handle(), csc.colPtr, nnz(csc), n, cooColInd, index)
617621
coo = CuSparseMatrixCOO{Tv}(csc.rowVal, cooColInd, nonzeros(csc), size(csc))

lib/cusparse/linalg.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ function Base.reshape(A::CuSparseMatrixCOO, dims::Dims)
5757
sparse(new_row, new_col, A.nzVal, dims[1], length(dims) == 1 ? 1 : dims[2], fmt = :coo)
5858
end
5959

60-
function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}) where {T}
60+
function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::CuSparseMatrixCOO{T, Ti}) where {Ti, T}
6161
mA,nA = size(A)
6262
mB,nB = size(B)
6363
out_shape = (mA * mB, nA * nB)
6464
Annz = Int64(A.nnz)
6565
Bnnz = Int64(B.nnz)
6666

6767
if Annz == 0 || Bnnz == 0
68-
return CuSparseMatrixCOO(CuVector{T}(undef, 0), CuVector{T}(undef, 0), CuVector{T}(undef, 0), out_shape)
68+
return CuSparseMatrixCOO(CuVector{Ti}(undef, 0), CuVector{Ti}(undef, 0), CuVector{T}(undef, 0), out_shape)
6969
end
7070

7171
row = (A.rowInd .- 1) .* mB
@@ -82,15 +82,15 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}) wh
8282
sparse(row, col, data, out_shape..., fmt = :coo)
8383
end
8484

85-
function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::Diagonal) where {T}
85+
function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::Diagonal) where {Ti, T}
8686
mA,nA = size(A)
8787
mB,nB = size(B)
8888
out_shape = (mA * mB, nA * nB)
8989
Annz = Int64(A.nnz)
9090
Bnnz = nB
9191

9292
if Annz == 0 || Bnnz == 0
93-
return CuSparseMatrixCOO(CuVector{T}(undef, 0), CuVector{T}(undef, 0), CuVector{T}(undef, 0), out_shape)
93+
return CuSparseMatrixCOO(CuVector{Ti}(undef, 0), CuVector{Ti}(undef, 0), CuVector{T}(undef, 0), out_shape)
9494
end
9595

9696
row = (A.rowInd .- 1) .* mB
@@ -107,15 +107,15 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::Diagonal) where {T}
107107
sparse(row, col, data, out_shape..., fmt = :coo)
108108
end
109109

110-
function LinearAlgebra.kron(A::Diagonal, B::CuSparseMatrixCOO{T}) where {T}
110+
function LinearAlgebra.kron(A::Diagonal, B::CuSparseMatrixCOO{T, Ti}) where {Ti, T}
111111
mA,nA = size(A)
112112
mB,nB = size(B)
113113
out_shape = (mA * mB, nA * nB)
114114
Annz = nA
115115
Bnnz = Int64(B.nnz)
116116

117117
if Annz == 0 || Bnnz == 0
118-
return CuSparseMatrixCOO(CuVector{T}(undef, 0), CuVector{T}(undef, 0), CuVector{T}(undef, 0), out_shape)
118+
return CuSparseMatrixCOO(CuVector{Ti}(undef, 0), CuVector{Ti}(undef, 0), CuVector{T}(undef, 0), out_shape)
119119
end
120120

121121
row = (0:nA-1) .* mB

test/libraries/cusparse.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,15 @@ blockdim = 5
196196
d_x = UpperTriangular(CuSparseMatrixCSC(x))
197197
@test istriu(d_x)
198198
@test !istril(d_x)
199+
d_x = UpperTriangular(transpose(CuSparseMatrixCSC(x)))
200+
@test istriu(d_x)
201+
@test !istril(d_x)
199202
d_x = LowerTriangular(CuSparseMatrixCSC(x))
200203
@test !istriu(d_x)
201204
@test istril(d_x)
205+
d_x = LowerTriangular(transpose(CuSparseMatrixCSC(x)))
206+
@test !istriu(d_x)
207+
@test istril(d_x)
202208
d_x = UpperTriangular(triu(transpose(CuSparseMatrixCSC(x)), 1))
203209
@test istriu(d_x)
204210
@test !istril(d_x)

test/libraries/cusparse/linalg.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,23 @@ end
4343
b = sprand(ComplexF32, 100, 100, 0.1)
4444
A = typ(a)
4545
B = typ(b)
46+
za = spzeros(ComplexF32, 100, 100)
47+
ZA = typ(za)
4648
@test collect(kron(A, B)) kron(a, b)
4749
@test collect(kron(transpose(A), B)) kron(transpose(a), b)
4850
@test collect(kron(A, transpose(B))) kron(a, transpose(b))
4951
@test collect(kron(transpose(A), transpose(B))) kron(transpose(a), transpose(b))
5052
@test collect(kron(A', B)) kron(a', b)
5153
@test collect(kron(A, B')) kron(a, b')
5254
@test collect(kron(A', B')) kron(a', b')
55+
56+
@test collect(kron(ZA, B)) kron(za, b)
57+
@test collect(kron(transpose(ZA), B)) kron(transpose(za), b)
58+
@test collect(kron(ZA, transpose(B))) kron(za, transpose(b))
59+
@test collect(kron(transpose(ZA), transpose(B))) kron(transpose(za), transpose(b))
60+
@test collect(kron(ZA', B)) kron(za', b)
61+
@test collect(kron(ZA, B')) kron(za, b')
62+
@test collect(kron(ZA', B')) kron(za', b')
5363

5464
C = I(50)
5565
@test collect(kron(A, C)) kron(a, C)
@@ -88,4 +98,5 @@ end
8898
A2 = typ(A)
8999

90100
@test dot(x, A, y) dot(x2, A2, y2)
101+
@test_throws DimensionMismatch("dimensions must match") dot(CUDA.rand(elty, N1+1), A2, y2)
91102
end

0 commit comments

Comments
 (0)