Skip to content

Commit b4ac08a

Browse files
authored
Add support for 3-arg dot. (#2914)
1 parent 0b96ba1 commit b4ac08a

3 files changed

Lines changed: 93 additions & 1 deletion

File tree

lib/cublas/linalg.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,92 @@ function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
108108
end
109109
end
110110

111+
# three-argument dot, dot(x, A, y) == x' * A * y: avoid materializing the
# intermediate vector A*y by fusing the matvec and the reduction in one kernel.
function LinearAlgebra.dot(x::AnyCuArray{T1}, A::AnyCuArray{T2}, y::AnyCuArray{T3}) where {T1,T2,T3}
    nx = length(x)
    ny = length(y)
    mA, nA = size(A)

    nx == mA || throw(DimensionMismatch("length of x, $nx, does not match first dimension of A, $mA"))
    ny == nA || throw(DimensionMismatch("length of y, $ny, does not match second dimension of A, $nA"))

    # custom kernel using simple linear indexing and atomic additions
    # COV_EXCL_START
    function kernel(x, A, y, res::AbstractArray{T}, shuffle) where {T}
        local_val = zero(T)

        # grid-stride loop over rows
        i = threadIdx().x + (blockIdx().x - 1i32)*blockDim().x
        while i <= mA
            row_val = zero(T)
            for j in 1:nA
                # XXX: this is slow, but the focus is on avoiding materializing A*y
                @inbounds row_val += A[i, j] * y[j]
            end
            @inbounds local_val += LinearAlgebra.dot(x[i], row_val)
            i += blockDim().x * gridDim().x
        end

        val = CUDA.reduce_block(+, local_val, zero(T), shuffle)
        if threadIdx().x == 1i32
            # NOTE: introduces nondeterminism
            @inbounds CUDA.@atomic res[] += val
        end

        return
    end
    # COV_EXCL_STOP

    dev = device()
    let T = promote_type(T1, T2, T3)
        # only use the above kernel if we don't care about determinism
        # and if atomic operations are supported on these inputs
        atomic = if capability(dev) >= v"7.0"
            T <: Union{Int16, Int32, Int64, Float16, Float32, Float64}
        else
            T <: Union{Int32, Int64, Float32, Float64}
        end
        if CUDA.math_mode() == CUDA.PEDANTIC_MATH || !atomic
            # deterministic fallback: lazily broadcast the per-element
            # contributions x[i]' * A[i,j] * y[j] and reduce with sum.
            bc = Base.Broadcast.broadcasted(A, Base.CartesianIndices(A)) do a, I
                i, j = Tuple(I)
                LinearAlgebra.dot(x[i], a * y[j])
            end
            return sum(bc)
        end

        res = CUDA.zeros(T, 1)

        # be conservative about using shuffle instructions
        shuffle = T <: Union{Bool,
                             UInt8, UInt16, UInt32, UInt64, UInt128,
                             Int8, Int16, Int32, Int64, Int128,
                             Float16, Float32, Float64,
                             ComplexF16, ComplexF32, ComplexF64}

        # how many threads do we want?
        # reduce_block(shuffle=true) requires the block to consist of full warps.
        wanted_threads = shuffle ? nextwarp(dev, mA) : mA
        function compute_threads(max_threads)
            if wanted_threads > max_threads
                shuffle ? prevwarp(dev, max_threads) : max_threads
            else
                wanted_threads
            end
        end

        # how many threads can we launch?
        kernel_func = @cuda launch=false kernel(x, A, y, res, Val(shuffle))
        compute_shmem(threads) = shuffle ? 0 : threads*sizeof(T)
        # NOTE: `shmem` takes a callable mapping a candidate block size to its
        # dynamic shared-memory requirement, hence the composition `∘`
        # (the scraped source had the operator dropped: `compute_shmemcompute_threads`).
        config = launch_configuration(kernel_func.fun; shmem=compute_shmem∘compute_threads)
        threads = compute_threads(config.threads)
        blocks = min(config.blocks, cld(mA, config.blocks))
        shmem = compute_shmem(threads)
        kernel_func(x, A, y, res, Val(shuffle); threads, blocks, shmem)

        CUDA.@allowscalar res[]
    end
end
196+
111197
function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}},
112198
y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
113199
x = transx.parent

src/mapreduce.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,10 @@ end
168168
function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T},
169169
A::Union{AbstractArray,Broadcast.Broadcasted};
170170
init=nothing) where {F, OP, T}
171-
Base.check_reducedims(R, A)
171+
if !isa(A, Broadcast.Broadcasted)
172+
# XXX: Base.axes isn't defined anymore for Broadcasted, breaking this check
173+
Base.check_reducedims(R, A)
174+
end
172175
length(A) == 0 && return R # isempty(::Broadcasted) iterates
173176
dev = device()
174177

test/base/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ end
1313
ComplexF16, ComplexF32, ComplexF64]
1414
@test testf(dot, rand(T, 256), rand(Bool, 256))
1515
@test testf(dot, rand(Bool, 256), rand(T, 256))
16+
17+
@test testf(dot, rand(T, 256), rand(T, 256, 256), rand(Bool, 256))
18+
@test testf(dot, rand(Bool, 256), rand(T, 256, 256), rand(T, 256))
1619
end
1720

1821
@test testf(dot, rand(Bool, 1024, 1024), rand(Float64, 1024, 1024))

0 commit comments

Comments
 (0)