forked from JuliaGPU/CUDA.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdevice.jl
More file actions
144 lines (119 loc) · 5.07 KB
/
device.jl
File metadata and controls
144 lines (119 loc) · 5.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# on-device sparse array functionality
# should be excluded from coverage counts
# COV_EXCL_START
using SparseArrays
# NOTE: this functionality is currently very bare-bones, only defining the array types
# without any device-compatible sparse array functionality
# core types
export CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR,
CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO
# Sparse vector as seen from device code. `iPtr` and `nzVal` are device vectors
# (NOTE(review): presumably the indices of the stored entries and their values —
# confirm against the host-side conversion), `len` is the logical length and `nnz`
# the number of stored entries.
struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti}
    iPtr::CuDeviceVector{Ti, A}
    nzVal::CuDeviceVector{Tv, A}
    len::Int
    nnz::Ti
end

# Unlike the matrix types, the vector has no `dims` field: its logical length lives in
# `len`. Going through `prod(g.dims)` would throw a field error (and break `show`).
Base.length(g::CuSparseDeviceVector) = g.len
Base.size(g::CuSparseDeviceVector) = (g.len,)
SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz
# Compressed-sparse-column matrix as seen from device code. The SparseArrays accessors
# below expose `colPtr` as the column-pointer array, `rowVal` as the row values and
# `nzVal` as the stored values; `dims` is the matrix shape and `nnz` the stored count.
struct CuSparseDeviceMatrixCSC{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    colPtr::CuDeviceVector{Ti, A}
    rowVal::CuDeviceVector{Ti, A}
    nzVal::CuDeviceVector{Tv, A}
    dims::NTuple{2,Int}
    nnz::Ti
end

Base.size(g::CuSparseDeviceMatrixCSC) = g.dims
# number of logical entries (stored or not), i.e. the product of the dimensions
Base.length(g::CuSparseDeviceMatrixCSC) = prod(g.dims)

SparseArrays.nnz(g::CuSparseDeviceMatrixCSC) = g.nnz
SparseArrays.getcolptr(g::CuSparseDeviceMatrixCSC) = g.colPtr
SparseArrays.rowvals(g::CuSparseDeviceMatrixCSC) = g.rowVal
SparseArrays.getnzval(g::CuSparseDeviceMatrixCSC) = g.nzVal

# Positions of column `col`'s entries: from that column's pointer up to one before the
# next column's pointer.
function SparseArrays.nzrange(g::CuSparseDeviceMatrixCSC, col::Integer)
    colptr = SparseArrays.getcolptr(g)
    return colptr[col]:(colptr[col+1] - 1)
end
# Compressed-sparse-row matrix as seen from device code: `rowPtr`, `colVal` and `nzVal`
# device vectors (named after the usual CSR row-pointer / column-index / value arrays),
# the matrix shape `dims`, and the stored-entry count `nnz`.
struct CuSparseDeviceMatrixCSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowPtr::CuDeviceVector{Ti, A}
    colVal::CuDeviceVector{Ti, A}
    nzVal::CuDeviceVector{Tv, A}
    dims::NTuple{2, Int}
    nnz::Ti
end

Base.size(g::CuSparseDeviceMatrixCSR) = g.dims
# number of logical entries (stored or not): product of the shape
Base.length(g::CuSparseDeviceMatrixCSR) = prod(size(g))
SparseArrays.getnzval(g::CuSparseDeviceMatrixCSR) = g.nzVal
SparseArrays.nnz(g::CuSparseDeviceMatrixCSR) = g.nnz
# Block-sparse-row matrix as seen from device code. Alongside the CSR-style `rowPtr`,
# `colVal` and `nzVal` device vectors it records `blockDim` and `dir`
# (NOTE(review): presumably the block edge length and the in-block storage direction —
# confirm against the host-side BSR type), plus the shape `dims` and stored count `nnz`.
struct CuSparseDeviceMatrixBSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowPtr::CuDeviceVector{Ti, A}
    colVal::CuDeviceVector{Ti, A}
    nzVal::CuDeviceVector{Tv, A}
    dims::NTuple{2,Int}
    blockDim::Ti
    dir::Char
    nnz::Ti
end

Base.size(g::CuSparseDeviceMatrixBSR) = g.dims
# number of logical entries (stored or not): product of the shape
Base.length(g::CuSparseDeviceMatrixBSR) = prod(size(g))
SparseArrays.getnzval(g::CuSparseDeviceMatrixBSR) = g.nzVal
SparseArrays.nnz(g::CuSparseDeviceMatrixBSR) = g.nnz
# Coordinate-format matrix as seen from device code: `rowInd`, `colInd` and `nzVal`
# device vectors (presumably one row index and one column index per stored value —
# TODO confirm), the shape `dims`, and the stored-entry count `nnz`.
struct CuSparseDeviceMatrixCOO{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
    rowInd::CuDeviceVector{Ti, A}
    colInd::CuDeviceVector{Ti, A}
    nzVal::CuDeviceVector{Tv, A}
    dims::NTuple{2,Int}
    nnz::Ti
end

Base.size(g::CuSparseDeviceMatrixCOO) = g.dims
# number of logical entries (stored or not): product of the shape
Base.length(g::CuSparseDeviceMatrixCOO) = prod(size(g))
SparseArrays.getnzval(g::CuSparseDeviceMatrixCOO) = g.nzVal
SparseArrays.nnz(g::CuSparseDeviceMatrixCOO) = g.nnz
# N-dimensional CSR-style sparse array as seen from device code. The `rowPtr`, `colVal`
# and `nzVal` device arrays are M-dimensional; `dims` is the N-dimensional logical shape
# and `nnz` the stored-entry count.
struct CuSparseDeviceArrayCSR{Tv, Ti, N, M, A} <: AbstractSparseArray{Tv, Ti, N}
    rowPtr::CuDeviceArray{Ti, M, A}
    colVal::CuDeviceArray{Ti, M, A}
    nzVal::CuDeviceArray{Tv, M, A}
    dims::NTuple{N, Int}
    nnz::Ti
end

# Convenience constructor: `M` is taken from the argument arrays, which must have exactly
# one dimension fewer than `dims`, and `nnz` is set to `length(nzVal)`.
# NOTE(review): callers write four type parameters `{Tv, Ti, N, A}`, but the type's 4th
# slot is `M`; the method binds whatever is passed there as the address space `A` and
# fills the real `M` from the arguments. The arguments are `CuArray`s while the fields
# are `CuDeviceArray`s, so construction relies on an applicable `convert` — confirm.
function CuSparseDeviceArrayCSR{Tv, Ti, N, A}(
    rowPtr::CuArray{<:Integer, M},
    colVal::CuArray{<:Integer, M},
    nzVal::CuArray{Tv, M},
    dims::NTuple{N,<:Integer},
) where {Tv, Ti<:Integer, M, N, A}
    @assert M == N - 1 "CuSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
    stored = length(nzVal)
    return CuSparseDeviceArrayCSR{Tv, Ti, N, M, A}(rowPtr, colVal, nzVal, dims, stored)
end

Base.size(g::CuSparseDeviceArrayCSR) = g.dims
# number of logical entries (stored or not): product of the shape
Base.length(g::CuSparseDeviceArrayCSR) = prod(size(g))
SparseArrays.getnzval(g::CuSparseDeviceArrayCSR) = g.nzVal
SparseArrays.nnz(g::CuSparseDeviceArrayCSR) = g.nnz
# input/output
# Multi-line (text/plain) display: the logical length, then each backing device array.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceVector)
    n = length(A)
    println(io, "$(n)-element device sparse vector at:")
    println(io, " iPtr: $(A.iPtr)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# Multi-line (text/plain) display: the logical element count, then the CSR row-pointer,
# column and value arrays in field order.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSR)
    n = length(A)
    println(io, "$(n)-element device sparse matrix CSR at:")
    println(io, " rowPtr: $(A.rowPtr)")
    println(io, " colVal: $(A.colVal)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# Multi-line (text/plain) display: the logical element count, then the CSC
# column-pointer, row and value arrays in field order.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSC)
    n = length(A)
    println(io, "$(n)-element device sparse matrix CSC at:")
    println(io, " colPtr: $(A.colPtr)")
    println(io, " rowVal: $(A.rowVal)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# Multi-line (text/plain) display: the logical element count, then the BSR row-pointer,
# column and value arrays (block metadata is not printed).
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixBSR)
    n = length(A)
    println(io, "$(n)-element device sparse matrix BSR at:")
    println(io, " rowPtr: $(A.rowPtr)")
    println(io, " colVal: $(A.colVal)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# Multi-line (text/plain) display: the logical element count, then the COO row-index,
# column-index and value arrays in field order.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCOO)
    println(io, "$(length(A))-element device sparse matrix COO at:")
    # The COO type's row array is the `rowInd` field; it has no `rowPtr`, so reading
    # `A.rowPtr` would throw and make every COO device matrix undisplayable.
    println(io, " rowInd: $(A.rowInd)")
    println(io, " colInd: $(A.colInd)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# Multi-line (text/plain) display: the logical element count, then the row-pointer,
# column and value arrays of the N-dimensional CSR array in field order.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceArrayCSR)
    n = length(A)
    println(io, "$(n)-element device sparse array CSR at:")
    println(io, " rowPtr: $(A.rowPtr)")
    println(io, " colVal: $(A.colVal)")
    print(io, " nzVal: $(A.nzVal)")
    return nothing
end
# COV_EXCL_STOP