Skip to content

Commit 591f61d

Browse files
authored
More tests and better error type for CUSPARSE generic (#2744)
1 parent 12caaf5 commit 591f61d

2 files changed

Lines changed: 19 additions & 3 deletions

File tree

lib/cusparse/generic.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function densetosparse(A::CuMatrix{T}, fmt::Symbol, index::SparseChar, algo::cus
5353
colPtr = CuVector{Cint}(undef, n+1)
5454
desc_sparse = CuSparseMatrixDescriptor(CuSparseMatrixCSC, colPtr, T, Cint, m, n, index)
5555
else
56-
error("Format :$fmt not available, use :csc, :csr or :coo.")
56+
throw(ArgumentError("Format :$fmt not available, use :csc, :csr or :coo."))
5757
end
5858
desc_dense = CuDenseMatrixDescriptor(A)
5959

@@ -82,8 +82,6 @@ function densetosparse(A::CuMatrix{T}, fmt::Symbol, index::SparseChar, algo::cus
8282
nzVal = CuVector{T}(undef, nnzB[])
8383
B = CuSparseMatrixCSC{T, Cint}(colPtr, rowVal, nzVal, (m,n))
8484
cusparseCscSetPointers(desc_sparse, B.colPtr, B.rowVal, B.nzVal)
85-
else
86-
error("Format :$fmt not available, use :csc, :csr or :coo.")
8785
end
8886
cusparseDenseToSparse_convert(handle(), desc_dense, desc_sparse, algo, buffer)
8987
end

test/libraries/cusparse/generic.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ for SparseMatrixType in [CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO
240240
dA_dense = CuMatrix{T}(A_dense)
241241
dA_sparse = CUSPARSE.densetosparse(dA_dense, fmt[SparseMatrixType], 'O', algo)
242242
@test A_sparse collect(dA_sparse)
243+
@test_throws ArgumentError("Format :bad not available, use :csc, :csr or :coo.") CUSPARSE.densetosparse(dA_dense, :bad, 'O', algo)
243244
end
244245
end
245246
@testset "$SparseMatrixType -- sparsetodense algo=$algo" for algo in [CUSPARSE.CUSPARSE_SPARSETODENSE_ALG_DEFAULT]
@@ -362,8 +363,25 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
362363
F = alpha * opa(A) * opb(B) + beta * E
363364
dF = gemm(transa, transb, alpha, dA, dB, beta, dE, 'O', algo, same_pattern=false)
364365
@test F SparseMatrixCSC(dF)
366+
367+
# not same pattern
368+
G = sprand(T, 25, 35, 0.4)
369+
dG = SparseMatrixType(G)
370+
@test_throws ErrorException("AB and C must have the same sparsity pattern.") gemm!(transa, transb, gamma, dA, dB, beta, dG, 'O', algo)
371+
dG = gemm!(transa, transb, gamma, dA, dB, zero(T), dG, 'O', algo)
372+
H = gamma * opa(A) * opa(B) + zero(T) * G
373+
@test H SparseMatrixCSC(dG)
365374
end
366375
end
376+
if SparseMatrixType == CuSparseMatrixCSR
377+
A = sprand(T,25,10,0.2)
378+
B = sprand(T,10,35,0.3)
379+
dA = SparseMatrixType(A)
380+
dB = SparseMatrixType(B)
381+
C = A * B
382+
dC = SparseMatrixType(C)
383+
@test_throws ArgumentError("Sparse matrix-matrix multiplication only supports transa (T) = 'N' and transb (C) = 'N'") gemm!('T', 'C', one(T), dA, dB, zero(T), dC, 'O', algo)
384+
end
367385
end
368386
end
369387

0 commit comments

Comments
 (0)