@@ -327,6 +327,41 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrix
327327 end
328328end
329329
330+ # Wrapper around `cusparseCsr2cscEx2` that works around a cuSPARSE 12.0 bug:
331+ # invoking the routine with one-based indexing silently corrupts `cscVal` for
332+ # matrices with certain shapes (e.g. even `m`). On affected versions we shift
333+ # the index arrays to zero-based around the call; zero-based indexing returns
334+ # correct results.
335+ function _csr2cscEx2! (m, n, nnz_, csrVal:: CuVector{T} , csrRowPtr, csrColInd,
336+ cscVal:: CuVector{T} , cscColPtr, cscRowInd,
337+ action, index, algo) where T
338+ buggy = index == ' O' && v " 12.0" <= version () < v " 12.1"
339+ if buggy
340+ csrRowPtr = csrRowPtr .- one (Cint)
341+ csrColInd = csrColInd .- one (Cint)
342+ effidx = ' Z'
343+ else
344+ effidx = index
345+ end
346+ function bufferSize ()
347+ out = Ref {Csize_t} (1 )
348+ cusparseCsr2cscEx2_bufferSize (handle (), m, n, nnz_, csrVal,
349+ csrRowPtr, csrColInd, cscVal, cscColPtr, cscRowInd,
350+ T, action, effidx, algo, out)
351+ return out[]
352+ end
353+ with_workspace (bufferSize) do buffer
354+ cusparseCsr2cscEx2 (handle (), m, n, nnz_, csrVal,
355+ csrRowPtr, csrColInd, cscVal, cscColPtr, cscRowInd,
356+ T, action, effidx, algo, buffer)
357+ end
358+ if buggy
359+ cscColPtr .+ = one (Cint)
360+ cscRowInd .+ = one (Cint)
361+ end
362+ return
363+ end
364+
330365# by flipping rows and columns, we can use that to get CSC to CSR too
331366for elty in (:Float32 , :Float64 , :ComplexF32 , :ComplexF64 )
332367 @eval begin
@@ -335,18 +370,8 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
335370 colPtr = CUDACore. zeros (Cint, n+ 1 )
336371 rowVal = CUDACore. zeros (Cint, nnz (csr))
337372 nzVal = CUDACore. zeros ($ elty, nnz (csr))
338- function bufferSize ()
339- out = Ref {Csize_t} (1 )
340- cusparseCsr2cscEx2_bufferSize (handle (), m, n, nnz (csr), nonzeros (csr),
341- csr. rowPtr, csr. colVal, nzVal, colPtr, rowVal,
342- $ elty, action, index, algo, out)
343- return out[]
344- end
345- with_workspace (bufferSize) do buffer
346- cusparseCsr2cscEx2 (handle (), m, n, nnz (csr), nonzeros (csr),
347- csr. rowPtr, csr. colVal, nzVal, colPtr, rowVal,
348- $ elty, action, index, algo, buffer)
349- end
373+ _csr2cscEx2! (m, n, nnz (csr), nonzeros (csr), csr. rowPtr, csr. colVal,
374+ nzVal, colPtr, rowVal, action, index, algo)
350375 CuSparseMatrixCSC (colPtr,rowVal,nzVal,size (csr))
351376 end
352377 CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty, Ti} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1) where {Ti} =
@@ -358,18 +383,8 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
358383 rowPtr = CUDACore. zeros (Cint,m+ 1 )
359384 colVal = CUDACore. zeros (Cint,nnz (csc))
360385 nzVal = CUDACore. zeros ($ elty,nnz (csc))
361- function bufferSize ()
362- out = Ref {Csize_t} (1 )
363- cusparseCsr2cscEx2_bufferSize (handle (), n, m, nnz (csc), nonzeros (csc),
364- csc. colPtr, rowvals (csc), nzVal, rowPtr, colVal,
365- $ elty, action, index, algo, out)
366- return out[]
367- end
368- with_workspace (bufferSize) do buffer
369- cusparseCsr2cscEx2 (handle (), n, m, nnz (csc), nonzeros (csc),
370- csc. colPtr, rowvals (csc), nzVal, rowPtr, colVal,
371- $ elty, action, index, algo, buffer)
372- end
386+ _csr2cscEx2! (n, m, nnz (csc), nonzeros (csc), csc. colPtr, rowvals (csc),
387+ nzVal, rowPtr, colVal, action, index, algo)
373388 CuSparseMatrixCSR (rowPtr,colVal,nzVal,size (csc))
374389 end
375390 CuSparseMatrixCSR (csc:: CuSparseMatrixCSC{$elty, Ti} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1) where {Ti} =
@@ -390,18 +405,8 @@ for (elty, welty) in ((:Float16, :Float32),
390405 rowVal = CUDACore. zeros (Cint, nnz (csr))
391406 nzVal = CUDACore. zeros ($ elty, nnz (csr))
392407 if $ elty == Float16 # broken for ComplexF16?
393- function bufferSize ()
394- out = Ref {Csize_t} (1 )
395- cusparseCsr2cscEx2_bufferSize (handle (), m, n, nnz (csr), nonzeros (csr),
396- csr. rowPtr, csr. colVal, nzVal, colPtr, rowVal,
397- $ elty, action, index, algo, out)
398- return out[]
399- end
400- with_workspace (bufferSize) do buffer
401- cusparseCsr2cscEx2 (handle (), m, n, nnz (csr), nonzeros (csr),
402- csr. rowPtr, csr. colVal, nzVal, colPtr, rowVal,
403- $ elty, action, index, algo, buffer)
404- end
408+ _csr2cscEx2! (m, n, nnz (csr), nonzeros (csr), csr. rowPtr, csr. colVal,
409+ nzVal, colPtr, rowVal, action, index, algo)
405410 return CuSparseMatrixCSC (colPtr,rowVal,nzVal,size (csr))
406411 else
407412 wide_csr = CuSparseMatrixCSR (csr. rowPtr, csr. colVal, convert (CuVector{$ welty}, nonzeros (csr)), size (csr))
@@ -419,18 +424,8 @@ for (elty, welty) in ((:Float16, :Float32),
419424 colVal = CUDACore. zeros (Cint,nnz (csc))
420425 nzVal = CUDACore. zeros ($ elty,nnz (csc))
421426 if $ elty == Float16 # broken for ComplexF16?
422- function bufferSize ()
423- out = Ref {Csize_t} (1 )
424- cusparseCsr2cscEx2_bufferSize (handle (), n, m, nnz (csc), nonzeros (csc),
425- csc. colPtr, rowvals (csc), nzVal, rowPtr, colVal,
426- $ elty, action, index, algo, out)
427- return out[]
428- end
429- with_workspace (bufferSize) do buffer
430- cusparseCsr2cscEx2 (handle (), n, m, nnz (csc), nonzeros (csc),
431- csc. colPtr, rowvals (csc), nzVal, rowPtr, colVal,
432- $ elty, action, index, algo, buffer)
433- end
427+ _csr2cscEx2! (n, m, nnz (csc), nonzeros (csc), csc. colPtr, rowvals (csc),
428+ nzVal, rowPtr, colVal, action, index, algo)
434429 return CuSparseMatrixCSR (rowPtr,colVal,nzVal,size (csc))
435430 else
436431 wide_csc = CuSparseMatrixCSC (csc. colPtr, csc. rowVal, convert (CuVector{$ welty}, nonzeros (csc)), size (csc))
0 commit comments