Skip to content

Commit 87ea9d2

Browse files
maleadtclaude
andauthored
Fix dynamic function invocation in unsafe_strided_batch on Julia 1.12. (#3083)
Add `N` to the type parameters so the closure captures a fully-typed array, avoiding dynamic dispatch when the GPU compiler converts captured variables to device arrays. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5f45772 commit 87ea9d2

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

lib/cublas/src/wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,7 @@ end
13411341
end
13421342

13431343
# create a batch of pointers in device memory from a strided device array
1344-
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
1344+
@inline function unsafe_strided_batch(strided::DenseCuArray{T, N}) where {T, N}
13451345
batch_size = last(size(strided))
13461346
batch_stride = prod(size(strided)[1:end-1])
13471347
#ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]

lib/cublas/test/extensions.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ k = 13
102102
end
103103
end
104104

105+
@testset "unsafe_strided_batch" begin
106+
A = CuArray(rand(elty, m, m, 10))
107+
ptrs = cuBLAS.unsafe_strided_batch(A)
108+
@test ptrs isa CuVector{CuPtr{elty}}
109+
@test length(ptrs) == 10
110+
end
111+
105112
@testset "getrf_strided_batched!" begin
106113
Random.seed!(1)
107114
local k

0 commit comments

Comments
 (0)