@@ -549,51 +549,85 @@ end
549549function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSR{T} ,
550550 B:: CuSparseMatrixCSR{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT) where {T}
551551
552- m,k = size (A)
553- l,n = size (B)
552+ m, k = size (A)
553+ l, n = size (B)
554554
555- (k != l) && throw (DimensionMismatch (" A must have the same the number of columns that B has as rows, but A has $k columns and B has $l columns " ))
555+ (k != l) && throw (DimensionMismatch (" A must have the same number of columns that B has as rows, but A has $k columns and B has $l rows. " ))
556556 ! (transa == ' N' && transb == ' N' ) && throw (ArgumentError (" Sparse matrix-matrix multiplication only supports transa ($transa ) = 'N' and transb ($transb ) = 'N'" ))
557557
558+ alpha_ref = Ref {T} (convert (T, alpha))
559+ beta_ref = Ref {T} (zero (T))
560+
558561 descA = CuSparseMatrixDescriptor (A, index)
559562 descB = CuSparseMatrixDescriptor (B, index)
560563
561- rowPtr = CuVector {Cint} (undef, m+ 1 )
564+ rowPtr = CuVector {Cint} (undef, m + 1 )
562565 descC = CuSparseMatrixDescriptor (CuSparseMatrixCSR, rowPtr, T, Cint, m, n, index)
563566
564567 spgemm_desc = CuSpGEMMDescriptor ()
565568
566569 buffer1 = CuVector {UInt8} (undef, 0 )
567570 buffer2 = CuVector {UInt8} (undef, 0 )
568- GC. @preserve buffer1 buffer1 begin
571+ GC. @preserve buffer1 buffer2 rowPtr begin
569572 # compute an upper bound of the memory required for the intermediate products.
570573 function buffer1Size ()
571574 out = Ref {Csize_t} (0 )
572575 cusparseSpGEMM_workEstimation (
573- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} ( 0 ) ,
576+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
574577 descC, T, algo, spgemm_desc, out, CU_NULL)
575578 return out[]
576579 end
577580 with_workspace (buffer1, buffer1Size) do buffer
578581 out = Ref {Csize_t} (sizeof (buffer))
579582 cusparseSpGEMM_workEstimation (
580- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} ( 0 ) ,
583+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
581584 descC, T, algo, spgemm_desc, out, buffer)
582585 end
583586
584587 # compute the structure of the output matrix and its values in a temporary buffer
585- function buffer2Size ()
586- out = Ref {Csize_t} (0 )
587- cusparseSpGEMM_compute (
588- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (0 ),
589- descC, T, algo, spgemm_desc, out, CU_NULL)
590- return out[]
591- end
592- with_workspace (buffer2, buffer2Size) do buffer
593- out = Ref {Csize_t} (sizeof (buffer))
594- cusparseSpGEMM_compute (
595- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (0 ),
596- descC, T, algo, spgemm_desc, out, buffer)
588+ if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
589+ function buffer2Size ()
590+ out = Ref {Csize_t} (0 )
591+ cusparseSpGEMM_compute (
592+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
593+ descC, T, algo, spgemm_desc, out, CU_NULL)
594+ return out[]
595+ end
596+ with_workspace (buffer2, buffer2Size) do buffer
597+ out = Ref {Csize_t} (sizeof (buffer))
598+ cusparseSpGEMM_compute (
599+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
600+ descC, T, algo, spgemm_desc, out, buffer)
601+ end
602+ elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
603+ chunk_fraction = Cfloat (0.2 ) # as per NVIDIA example (make it configurable?)
604+ function buffer3Size ()
605+ out = Ref {Csize_t} (0 )
606+ cusparseSpGEMM_estimateMemory (
607+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
608+ descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0 )
609+ return out[]
610+ end
611+ with_workspace (buffer3Size) do buffer3
612+ function buffer2Size ()
613+ out = Ref {Csize_t} (0 )
614+ cusparseSpGEMM_estimateMemory (
615+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
616+ descC, T, algo, spgemm_desc, chunk_fraction, sizeof (buffer3),
617+ buffer3, out)
618+ return out[]
619+ end
620+ with_workspace (buffer2, buffer2Size) do buffer
621+ unsafe_free! (buffer3)
622+
623+ out = Ref {Csize_t} (sizeof (buffer))
624+ cusparseSpGEMM_compute (
625+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
626+ descC, T, algo, spgemm_desc, out, buffer)
627+ end
628+ end
629+ else
630+ throw (ArgumentError (" Unsupported SpGEMM algorithm: $algo " ))
597631 end
598632 CUDA. unsafe_free! (buffer1)
599633
0 commit comments