Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 35 additions & 38 deletions lib/cusparse/src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,22 @@ Base.similar(Mat::CuSparseMatrixCOO, dims::Tuple{Int, Int}) = similar(Mat, dims.

Base.similar(Mat::CuSparseArrayCSR) = CuSparseArrayCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))


function Base.similar(mat::CuSparseMatrixCOO, ::Type{T}, dims::Dims{2}) where {T}
new_rowInd = similar(mat.rowInd)
new_colInd = similar(mat.colInd)
new_nzVal = similar(mat.nzVal, T)
return CuSparseMatrixCOO{T, Ti}(new_rowInd, new_colInd, new_nzVal, dims)
end

function Base.similar(mat::CuSparseMatrixBSR{Tv, Ti}, ::Type{T}, dims::Dims{2}) where {Tv, Ti, T}
new_rowPtr = similar(mat.rowPtr)
new_colVal = similar(mat.colVal)
new_nzVal = similar(mat.nzVal, T)

return CuSparseMatrixBSR{T, Ti}(new_rowPtr, new_colVal, new_nzVal, dims, mat.blockDim, mat.dir, mat.nnzb)
end

## array interface

Base.length(g::CuSparseVector) = g.len
Expand Down Expand Up @@ -462,57 +478,39 @@ Base.getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse

function Base.getindex(A::CuSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
@boundscheck checkbounds(A, i)
result = zero(Tv)
for k in 1:nnz(A)
A.iPtr[k] == i && (result = sum_duplicate(result, A.nzVal[k]))
end
return result
ii = searchsortedfirst(A.iPtr, convert(Ti, i))
Comment thread
rainerrodrigues marked this conversation as resolved.
(ii > nnz(A) || A.iPtr[ii] != i) && return zero(Tv)
A.nzVal[ii]
end

# Scalar getindex methods linear-scan the minor axis rather than binary-searching
# and sum across matching entries. cuSPARSE formats don't guarantee sorted indices
# within a major-axis slice (e.g. SpGEMM output may leave CSR columns unsorted
# within a row, and COO is only guaranteed row-sorted), nor uniqueness — duplicate
# (i, j) entries are permitted and their values sum, matching the convention of
# Julia's `sparse()` constructor and SciPy/CuPy. For Bool we OR instead of sum,
# also matching `sparse()`, since Bool + Bool doesn't stay Bool.
sum_duplicate(a, b) = a + b
sum_duplicate(a::Bool, b::Bool) = a | b

function Base.getindex(A::CuSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
@boundscheck checkbounds(A, i0, i1)
r1 = Int(A.colPtr[i1])
r2 = Int(A.colPtr[i1+1]-1)
result = zero(T)
for k in r1:r2
rowvals(A)[k] == i0 && (result = sum_duplicate(result, nonzeros(A)[k]))
end
return result
(r1 > r2) && return zero(T)
r1 = searchsortedfirst(rowvals(A), i0, r1, r2, Base.Order.Forward)
(r1 > r2 || rowvals(A)[r1] != i0) && return zero(T)
nonzeros(A)[r1]
end

function Base.getindex(A::CuSparseMatrixCSR{T}, i0::Integer, i1::Integer) where T
@boundscheck checkbounds(A, i0, i1)
c1 = Int(A.rowPtr[i0])
c2 = Int(A.rowPtr[i0+1]-1)
result = zero(T)
for k in c1:c2
A.colVal[k] == i1 && (result = sum_duplicate(result, nonzeros(A)[k]))
end
return result
(c1 > c2) && return zero(T)
c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
(c1 > c2 || A.colVal[c1] != i1) && return zero(T)
nonzeros(A)[c1]
end

function Base.getindex(A::CuSparseMatrixCOO{T}, i0::Integer, i1::Integer) where T
@boundscheck checkbounds(A, i0, i1)
# cuSPARSE only guarantees COO is sorted by row, so binary-search the row
# range but linear-scan for the column.
r1 = searchsortedfirst(A.rowInd, i0, Base.Order.Forward)
(r1 > length(A.rowInd) || A.rowInd[r1] > i0) && return zero(T)
r2 = searchsortedlast(A.rowInd, i0, Base.Order.Forward)
result = zero(T)
for k in r1:r2
A.colInd[k] == i1 && (result = sum_duplicate(result, nonzeros(A)[k]))
end
return result
r2 = min(searchsortedfirst(A.rowInd, i0+1, Base.Order.Forward), length(A.rowInd))
c1 = searchsortedfirst(A.colInd, i1, r1, r2, Base.Order.Forward)
(c1 > r2 || c1 == length(A.colInd) + 1 || A.colInd[c1] > i1) && return zero(T)
nonzeros(A)[c1]
end

function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where T
Expand All @@ -522,11 +520,10 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
block_idx = (i0_idx - 1) * A.blockDim + i1_idx - 1
c1 = Int(A.rowPtr[i0_block])
c2 = Int(A.rowPtr[i0_block+1]-1)
result = zero(T)
for k in c1:c2
A.colVal[k] == i1_block && (result = sum_duplicate(result, nonzeros(A)[k+block_idx]))
end
return result
(c1 > c2) && return zero(T)
c1 = searchsortedfirst(A.colVal, i1_block, c1, c2, Base.Order.Forward)
(c1 > c2 || A.colVal[c1] != i1_block) && return zero(T)
nonzeros(A)[c1+block_idx]
end

# matrix slices
Expand Down