@@ -108,6 +108,92 @@ function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
108108 end
109109end
110110
111+ # three-argument dot: avoid materializing A*y
# Three-argument dot: compute `dot(x, A*y)` (i.e. `x' * A * y`) without
# materializing the intermediate vector `A*y`.
function LinearAlgebra.dot(x::AnyCuArray{T1}, A::AnyCuArray{T2}, y::AnyCuArray{T3}) where {T1,T2,T3}
    nx = length(x)
    ny = length(y)
    mA, nA = size(A)

    nx == mA || throw(DimensionMismatch("length of x, $nx, does not match first dimension of A, $mA"))
    ny == nA || throw(DimensionMismatch("length of y, $ny, does not match second dimension of A, $nA"))

    # custom kernel using simple linear indexing and atomic additions;
    # captures `mA`/`nA` from the enclosing scope.
    # COV_EXCL_START
    function kernel(x, A, y, res::AbstractArray{T}, shuffle) where {T}
        local_val = zero(T)

        # grid-stride loop over rows: each thread accumulates whole rows
        i = threadIdx().x + (blockIdx().x - 1i32)*blockDim().x
        while i <= mA
            row_val = zero(T)
            for j in 1:nA
                # XXX: this is slow, but the focus is on avoiding materializing A*y
                @inbounds row_val += A[i, j] * y[j]
            end
            # dot(x[i], ·) conjugates x[i] for complex eltypes
            @inbounds local_val += LinearAlgebra.dot(x[i], row_val)
            i += blockDim().x * gridDim().x
        end

        # block-level reduction, then one atomic update per block
        val = CUDA.reduce_block(+, local_val, zero(T), shuffle)
        if threadIdx().x == 1i32
            # NOTE: introduces nondeterminism (atomic float addition order varies)
            @inbounds CUDA.@atomic res[] += val
        end

        return
    end
    # COV_EXCL_STOP

    dev = device()
    let T = promote_type(T1, T2, T3)
        # only use the above kernel if we don't care about determinism
        # and if atomic operations are supported on these inputs
        atomic = if capability(dev) >= v"7.0"
            T <: Union{Int16, Int32, Int64, Float16, Float32, Float64}
        else
            T <: Union{Int32, Int64, Float32, Float64}
        end
        if CUDA.math_mode() == CUDA.PEDANTIC_MATH || !atomic
            # deterministic fallback: lazily map each A[i,j] to dot(x[i], A[i,j]*y[j])
            # and let `sum` perform a tree reduction — still no A*y temporary.
            bc = Base.Broadcast.broadcasted(A, Base.CartesianIndices(A)) do a, I
                i, j = Tuple(I)
                LinearAlgebra.dot(x[i], a * y[j])
            end
            return sum(bc)
        end

        # accumulator for the atomic reduction
        res = CUDA.zeros(T, 1)

        # be conservative about using shuffle instructions
        shuffle = T <: Union{Bool,
                             UInt8, UInt16, UInt32, UInt64, UInt128,
                             Int8, Int16, Int32, Int64, Int128,
                             Float16, Float32, Float64,
                             ComplexF16, ComplexF32, ComplexF64}

        # how many threads do we want?
        # reduce_block(shuffle=true) requires the block to consist of full warps.
        wanted_threads = shuffle ? nextwarp(dev, mA) : mA
        function compute_threads(max_threads)
            if wanted_threads > max_threads
                shuffle ? prevwarp(dev, max_threads) : max_threads
            else
                wanted_threads
            end
        end

        # how many threads can we launch?
        kernel_func = @cuda launch=false kernel(x, A, y, res, Val(shuffle))
        compute_shmem(threads) = shuffle ? 0 : threads*sizeof(T)
        config = launch_configuration(kernel_func.fun; shmem=compute_shmem∘compute_threads)
        threads = compute_threads(config.threads)
        # FIX: partition rows over threads — launch enough blocks for one row per
        # thread, capped by the occupancy limit. The original divided the row count
        # by `config.blocks`, which corresponds to no meaningful work split.
        blocks = min(config.blocks, cld(mA, threads))
        shmem = compute_shmem(threads)
        kernel_func(x, A, y, res, Val(shuffle); threads, blocks, shmem)

        # single-element read-back of the reduced result
        CUDA.@allowscalar res[]
    end
end
196+
111197function LinearAlgebra.:(* )(transx:: Transpose{<:Any,<:StridedCuVector{T}} ,
112198 y:: StridedCuVector{T} ) where T<: Union{ComplexF16, CublasComplex}
113199 x = transx. parent
0 commit comments