@@ -57,15 +57,15 @@ function Base.reshape(A::CuSparseMatrixCOO, dims::Dims)
5757 sparse (new_row, new_col, A. nzVal, dims[1 ], length (dims) == 1 ? 1 : dims[2 ], fmt = :coo )
5858end
5959
60- function LinearAlgebra. kron (A:: CuSparseMatrixCOO{T} , B:: CuSparseMatrixCOO{T} ) where {T}
60+ function LinearAlgebra. kron (A:: CuSparseMatrixCOO{T, Ti } , B:: CuSparseMatrixCOO{T, Ti } ) where {Ti, T}
6161 mA,nA = size (A)
6262 mB,nB = size (B)
6363 out_shape = (mA * mB, nA * nB)
6464 Annz = Int64 (A. nnz)
6565 Bnnz = Int64 (B. nnz)
6666
6767 if Annz == 0 || Bnnz == 0
68- return CuSparseMatrixCOO (CuVector {T } (undef, 0 ), CuVector {T } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
68+ return CuSparseMatrixCOO (CuVector {Ti } (undef, 0 ), CuVector {Ti } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
6969 end
7070
7171 row = (A. rowInd .- 1 ) .* mB
@@ -82,15 +82,15 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}) wh
8282 sparse (row, col, data, out_shape... , fmt = :coo )
8383end
8484
85- function LinearAlgebra. kron (A:: CuSparseMatrixCOO{T} , B:: Diagonal ) where {T}
85+ function LinearAlgebra. kron (A:: CuSparseMatrixCOO{T, Ti } , B:: Diagonal ) where {Ti, T}
8686 mA,nA = size (A)
8787 mB,nB = size (B)
8888 out_shape = (mA * mB, nA * nB)
8989 Annz = Int64 (A. nnz)
9090 Bnnz = nB
9191
9292 if Annz == 0 || Bnnz == 0
93- return CuSparseMatrixCOO (CuVector {T } (undef, 0 ), CuVector {T } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
93+ return CuSparseMatrixCOO (CuVector {Ti } (undef, 0 ), CuVector {Ti } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
9494 end
9595
9696 row = (A. rowInd .- 1 ) .* mB
@@ -107,15 +107,15 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T}, B::Diagonal) where {T}
107107 sparse (row, col, data, out_shape... , fmt = :coo )
108108end
109109
110- function LinearAlgebra. kron (A:: Diagonal , B:: CuSparseMatrixCOO{T} ) where {T}
110+ function LinearAlgebra. kron (A:: Diagonal , B:: CuSparseMatrixCOO{T, Ti } ) where {Ti, T}
111111 mA,nA = size (A)
112112 mB,nB = size (B)
113113 out_shape = (mA * mB, nA * nB)
114114 Annz = nA
115115 Bnnz = Int64 (B. nnz)
116116
117117 if Annz == 0 || Bnnz == 0
118- return CuSparseMatrixCOO (CuVector {T } (undef, 0 ), CuVector {T } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
118+ return CuSparseMatrixCOO (CuVector {Ti } (undef, 0 ), CuVector {Ti } (undef, 0 ), CuVector {T} (undef, 0 ), out_shape)
119119 end
120120
121121 row = (0 : nA- 1 ) .* mB
0 commit comments