Skip to content

Commit 76692d5

Browse files
Added Base.similar methods for CuSparseMatrixCOO and BSR
1 parent 7a46bf3 commit 76692d5

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

lib/cusparse/src/array.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,22 @@ Base.similar(Mat::CuSparseMatrixCOO, dims::Tuple{Int, Int}) = similar(Mat, dims.
303303

304304
Base.similar(Mat::CuSparseArrayCSR) = CuSparseArrayCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
305305

306+
307+
function Base.similar(mat::CuSparseMatrixCOO, ::Type{T}, dims::Dims{2}) where {T}
308+
new_rowInd = similar(mat.rowInd)
309+
new_colInd = similar(mat.colInd)
310+
new_nzVal = similar(mat.nzVal, T)
311+
return CuSparseMatrixCOO(new_rowInd, new_colInd, new_nzVal, dims)
312+
end
313+
314+
function Base.similar(mat::CuSparseMatrixBSR{Tv, Ti}, ::Type{T}, dims::Dims{2}) where {Tv, Ti, T}
315+
new_rowPtr = similar(mat.rowPtr)
316+
new_colVal = similar(mat.colVal)
317+
new_nzVal = similar(mat.nzVal, T)
318+
319+
return CuSparseMatrixBSR{T, Ti}(new_rowPtr, new_colVal, new_nzVal, dims, mat.blockDim, mat.dir, mat.nnzb)
320+
end
321+
306322
## array interface
307323

308324
Base.length(g::CuSparseVector) = g.len

0 commit comments

Comments
 (0)