
Adapt to CUDA.jl v6 #2749

Merged
avik-pal merged 18 commits into main from mg/cuda-6 on Apr 19, 2026

Conversation

@giordano
Member

@giordano giordano commented Mar 30, 2026

Not quite ready, especially because it depends on upstream packages (ArrayInterface, Flux, Lux, LuxLib, NNlib, NonuniformFFTs, OneHotArrays) first adapting to the upcoming CUDA v6, but I'm saving my progress so far; at least with these changes I can barely precompile the CUDA extension.

@giordano giordano marked this pull request as draft March 30, 2026 16:57
@wsmoses
Member

wsmoses commented Mar 30, 2026

is it possible to wait on this for 2 weeks?

Comment thread ext/ReactantCUDAExt.jl
-    CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo),
-    CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx);
+    GPUCompiler.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo),
+    CUDACore.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx);
Member

this presumably breaks on cuda 5 right?

Member Author

As is right now yes, but I'm pretty sure we can simply define const CUDACore = CUDA when CUDACore isn't defined, thus making all changes compatible with v5 as well. And I pushed some changes in CUDA.jl itself to break fewer things (like making all the @device_* macros available in the CUDA scope)
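
The alias trick mentioned above can be sketched as follows. This is a minimal, self-contained illustration with stand-in modules and a hypothetical `cudacore` helper, not the real extension code; the real change would be a top-level `const CUDACore = CUDA` fallback when `CUDA.CUDACore` is not defined.

```julia
# Minimal sketch of the compatibility alias, assuming the only relevant
# difference is whether the compiler internals live in a `CUDACore`
# submodule (CUDA.jl v6) or in `CUDA` itself (v5). The modules below are
# hypothetical stand-ins, not the real packages.
module FakeCUDAv5 end            # v5 layout: everything lives in CUDA directly

module FakeCUDAv6
    module CUDACore end          # v6 layout: internals moved to CUDACore
end

# Resolve the module that holds the compiler internals for either layout.
cudacore(cuda::Module) = isdefined(cuda, :CUDACore) ? cuda.CUDACore : cuda

@assert cudacore(FakeCUDAv5) === FakeCUDAv5
@assert cudacore(FakeCUDAv6) === FakeCUDAv6.CUDACore
```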

Member Author

@giordano giordano Mar 31, 2026

With cff4b0a the extension should be fully compatible with both CUDA v5 and v6.

@giordano
Member Author

is it possible to wait on this for 2 weeks?

As I mentioned above and elsewhere, this requires a lot of other packages to update to CUDA.jl v6 (which isn't even released)

@giordano
Member Author

Side note, half of the changes to the extension are actually bug fixes independent of the upgrade to v6 (which only exposed the bugs), like trying to import symbols from the wrong modules.

Comment thread ext/ReactantCUDAExt.jl Outdated
@giordano giordano marked this pull request as ready for review April 17, 2026 17:27
@giordano
Member Author

This is ready for review. I'd like to tag the new version after this is merged; some users have already reported issues with Reactant being downgraded to very old versions when installing CUDA in the same environment.

@giordano giordano requested a review from wsmoses April 17, 2026 17:30
@avik-pal
Collaborator

KA tests are broken

@giordano
Member Author

giordano commented Apr 18, 2026

It'd make my life easier to get a direct link, instead of having to scavenge the test logs.

@giordano
Member Author

https://github.com/EnzymeAD/Reactant.jl/actions/runs/24578105753/job/71868640907?pr=2749#step:23:1470

ERROR: The following 1 direct dependency failed to precompile:

CUDAExt --code-coverage=@/home/runner/work/Reactant.jl/Reactant.jl --color=yes --check-bounds=yes --warn-overwrite=yes --depwarn=yes --inline=yes --startup-file=no --track-allocation=none --check-bounds=yes --compiled-modules=yes --depwarn=yes 

Failed to precompile CUDAExt [8d20f71a-eaa5-5402-8cb1-1e6062ff668e] to "/home/runner/.julia/compiled/v1.11/CUDAExt/jl_RwIJ4b".
WARNING: importing deprecated binding CUDA.CUBLAS into CUDAExt, use cuBLAS instead.
ERROR: LoadError: UndefVarError: `APIUtils` not defined in `CUDA`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:42
 [2] top-level scope
   @ ~/.julia/packages/LuxLib/lngmK/ext/CUDAExt/cublaslt.jl:72
 [3] include(mod::Module, _path::String)
   @ Base ./Base.jl:562
 [4] include(x::String)
   @ CUDAExt ~/.julia/packages/LuxLib/lngmK/ext/CUDAExt/CUDAExt.jl:1
 [5] top-level scope
   @ ~/.julia/packages/LuxLib/lngmK/ext/CUDAExt/CUDAExt.jl:11
 [6] include
   @ ./Base.jl:562 [inlined]
 [7] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing)
   @ Base ./loading.jl:2924
 [8] top-level scope
   @ stdin:6
in expression starting at /home/runner/.julia/packages/LuxLib/lngmK/ext/CUDAExt/cublaslt.jl:72
in expression starting at /home/runner/.julia/packages/LuxLib/lngmK/ext/CUDAExt/CUDAExt.jl:1
in expression starting at stdin:

Looks to me like a bug in LuxLib.

@giordano
Member Author

giordano commented Apr 18, 2026

I'm going to assume the "broken KA tests" are https://buildkite.com/julialang/reactant-dot-jl/builds/17726#019d9c7c-885e-47ad-94da-9f8c2b1de61b/L2389. A standalone reproducer is (requires an Nvidia GPU)

julia> using CUDA, KernelAbstractions, Reactant

julia> @kernel function square_kernel!(y, @Const(x))
           i = @index(Global)
           @inbounds y[i] = x[i] * x[i]
       end
square_kernel! (generic function with 4 methods)

julia> function square(x)
           y = similar(x)
           backend = KernelAbstractions.get_backend(x)
           kernel! = square_kernel!(backend)
           kernel!(y, x; ndrange=length(x))
           return y
       end
square (generic function with 1 method)

julia> x = Reactant.to_rarray(collect(1:1:64) ./ 64);

julia> @jit(raise = false, square(x));
Warning: detected a stack overflow; program state may be corrupted, so further execution might be unreliable.
Warning: detected a stack overflow; program state may be corrupted, so further execution might be unreliable.
ERROR: StackOverflowError:
Stacktrace:
     [1] rethrow()
       @ Base ./error.jl:71
     [2] macro expansion
       @ ./lock.jl:378 [inlined]
     [3] cufunction(f::typeof(gpu_square_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
       @ ReactantCUDAExt /mnt/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl:1599
     [4] call_with_reactant
       @ ./none:-1 [inlined]
     [5] call_with_reactant(::Reactant.EnsureReturnType{Any}, ::typeof(cufunction), ::typeof(gpu_square_kernel!), ::Type{Tuple{…}})
       @ Reactant /mnt/giordano/.julia/dev/Reactant/src/utils.jl:0
     [6] #launch_configuration#9
       @ /mnt/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl:614
     [7] call_with_reactant
       @ ./none:-1 [inlined]
     [8] call_with_reactant(::Reactant.EnsureReturnType{…}, ::ReactantCUDAExt.var"##launch_configuration#9", ::Int64, ::Int64, ::typeof(launch_configuration), ::Reactant.Compiler.LLVMFunc{…})
       @ Reactant /mnt/giordano/.julia/dev/Reactant/src/utils.jl:0
--- the above 2 lines are repeated 1 more time ---
--- the above 5 lines are repeated 6868 more times ---
 [34351] ka_with_reactant
       @ /mnt/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl:535
 [34352] call_with_reactant
       @ ./none:-1 [inlined]
 [34353] call_with_reactant(::typeof(Reactant.ka_with_reactant), ::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
       @ Reactant /mnt/giordano/.julia/dev/Reactant/src/utils.jl:0
 [34354] (::KernelAbstractions.Kernel{…})(::Reactant.TracedRArray{…}, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
       @ ReactantKernelAbstractionsExt /mnt/giordano/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:128
 [34355] square
       @ ./REPL[5]:5
 [34356] call_with_reactant
       @ ./none:-1 [inlined]
 [34357] call_with_reactant(::typeof(square), ::Reactant.TracedRArray{Float64, 1})
       @ Reactant /mnt/giordano/.julia/dev/Reactant/src/utils.jl:0
 [34358] make_mlir_fn(f::typeof(square), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
       @ Reactant.TracedUtils /mnt/giordano/.julia/dev/Reactant/src/TracedUtils.jl:370

@giordano
Member Author

giordano commented Apr 18, 2026

I'll need some help digging into this. I followed

function Reactant.ka_with_reactant(ndrange, workgroupsize, obj, args...)
    backend = KA.backend(obj)
    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
        obj, ndrange, workgroupsize
    )
    # this might not be the final context, since we may tune the workgroupsize
    ctx = KA.mkcontext(obj, ndrange, iterspace)
    # If the kernel is statically sized we can tell the compiler about that
    if KA.workgroupsize(obj) <: KA.StaticSize
        maxthreads = prod(KA.get(KA.workgroupsize(obj)))
    else
        maxthreads = nothing
    end
    kernel = CUDA.@cuda launch = false always_inline = backend.always_inline maxthreads =
        maxthreads obj.f(ctx, args...)
and tried

using CUDA, KernelAbstractions, Reactant

const KA = KernelAbstractions

@kernel function square_kernel!(y, @Const(x))
    i = @index(Global)
    @inbounds y[i] = x[i] * x[i]
end

x = Reactant.to_rarray(collect(1:1:64) ./ 64);
y = similar(x);
backend = KernelAbstractions.get_backend(x)
kernel! = square_kernel!(backend)

ndrange, workgroupsize = length(x), nothing
obj = kernel!
args = (y, x);

ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
    obj, ndrange, workgroupsize
)
ctx = KA.mkcontext(obj, ndrange, iterspace)
maxthreads = nothing

kernel = CUDA.@cuda launch = false always_inline = backend.always_inline maxthreads =
    maxthreads obj.f(ctx, args...)

but I get various errors at the CUDA.@cuda call, on both main and this branch, so I'm likely doing something wrong. What's the flow for compiling a KernelAbstractions kernel?

@wsmoses
Member

wsmoses commented Apr 18, 2026

@maleadt was there any change to cufunction/friends?

@giordano
Member Author

The stacktrace suggests the stack overflow happens in CUDA.launch_configuration at

config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
but in my attempt to reduce the error I'm stuck before getting to that line.

@giordano
Member Author

How does

@noinline function CUDA.launch_configuration(
    f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
) where {F,tt}
    return CUDA.launch_configuration(
        Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
        shmem,
        max_threads,
    )
end

work? cufunction is overlaid at

Reactant.@reactant_overlay @noinline function CUDA.cufunction(
    f::F, tt::TT=Tuple{}; kwargs...
) where {F,TT}
    res = Base.@lock CUDA.cufunction_lock begin
        # compile the function
        cache = llvm_compiler_cache(MLIR.IR.current_module())
        effective_tt = _substitute_bfloat16_tt(
            tt, Reactant.Compiler.BFLOAT16_COMPILE_TYPE[]
        )
        source = CUDA.methodinstance(F, effective_tt)
        # cuda = CUDA.active_state()
        device = nothing # cuda.device
        # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
        cuda_cap = v"5.0"
        cuda_ptx = v"6.3"
        llvm_cap = v"5.0"
        llvm_ptx = v"6.3"
        kernel = true
        always_inline = false
        name = nothing
        debuginfo = false
        config = GPUCompiler.CompilerConfig(
            CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo),
            CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx);
            kernel,
            name,
            always_inline,
            optimize=false,
            cleanup=false,
            validate=false,
            libraries=false,
        )
        GPUCompiler.cached_compilation(cache, source, config, compile, link)
    end
    return Core.Typeof(res)(f, res.entry)
end
but as far as I can tell it returns Reactant.Compiler.LLVMFunc on main as well, so CUDA.launch_configuration is effectively an infinitely recursive function. Did this only work by chance so far? Or am I missing something?
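
To make the recursion concrete, here is a toy, self-contained model of the pattern; `MockLLVMFunc` and the `mock_*` functions are simplified stand-ins, not the real Reactant/CUDA definitions, and a depth counter stands in for the real StackOverflowError.

```julia
# Toy model of the recursion: under the overlay, cufunction returns an
# LLVMFunc wrapper instead of a plain CUDA function, so launch_configuration
# dispatches back to the LLVMFunc method and never reaches the native one.
struct MockLLVMFunc
    f::Any
end

# Stands in for the overlayed CUDA.cufunction: always wraps in MockLLVMFunc.
mock_cufunction(f) = MockLLVMFunc(f)

# Stands in for CUDA.launch_configuration(::LLVMFunc); a depth limit
# replaces the real StackOverflowError.
function mock_launch_configuration(f::MockLLVMFunc; depth=0, maxdepth=10)
    depth > maxdepth && return :would_stack_overflow
    inner = mock_cufunction(f.f)   # still a MockLLVMFunc, so we recurse
    return mock_launch_configuration(inner; depth=depth + 1, maxdepth)
end

@assert mock_launch_configuration(MockLLVMFunc(identity)) === :would_stack_overflow
```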

@wsmoses
Member

wsmoses commented Apr 18, 2026

we should really do:

Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;

->

call_with_native(CUDA.cufunction, f.f, Tuple{tt.parameters[2:end]...}).fun;

since essentially the gist there is that within the Reactant interp we can call into the native interp result

@giordano
Member Author

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index f1ce6e1dd..6f3c419ef 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -7,7 +7,8 @@ using Reactant:
     AnyConcretePJRTArray,
     MLIR,
     TracedRNumber,
-    ReactantPrecompilationException
+    ReactantPrecompilationException,
+    call_with_native
 using Reactant.Compiler: raising, LLVMFunc, llvm_compiler_cache
 using Reactant.Ops: @opcall
 
@@ -612,7 +613,7 @@ end
     f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
 ) where {F,tt}
     return CUDA.launch_configuration(
-        Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
+        call_with_native(CUDA.cufunction, f.f, Tuple{tt.parameters[2:end]...}).fun;
         shmem,
         max_threads,
     )

does fix the issue!
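
For intuition, why routing through the native compile breaks the cycle can be modeled with a self-contained toy (all names here — `Wrapper`, `NativeKernel`, `overlay_cufunction`, `native_cufunction`, `launch_config` — are hypothetical stand-ins, not the real API):

```julia
# Toy model of the fix: dispatching on the wrapper re-enters the same
# method, while the "native" call returns a plain handle, so
# launch_config reaches a base case instead of recursing.
struct Wrapper          # stands in for Reactant.Compiler.LLVMFunc
    f::Any
end
struct NativeKernel     # stands in for a plain compiled kernel handle
    fun::Symbol
end

# Overlayed compile: always hands back another Wrapper (what caused the loop).
overlay_cufunction(f) = Wrapper(f)
# Native compile: produces a plain kernel handle (what call_with_native reaches).
native_cufunction(f) = NativeKernel(:cufun)

launch_config(k::NativeKernel) = (blocks = 1, threads = 256)  # base case
# The fix in the diff above, in miniature: route through the native compile
# instead of the overlayed one, so dispatch reaches the base case.
launch_config(w::Wrapper) = launch_config(native_cufunction(w.f))

@assert launch_config(Wrapper(identity)) == (blocks = 1, threads = 256)
```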

Member

@wsmoses wsmoses left a comment

If it passes LGTM

@giordano
Member Author

I believe we need to wait for LuxDL/Lux.jl#1696 for a cleaner run; the Lux integration tests are going to fail without it.

@avik-pal
Collaborator

JuliaRegistries/General#153282

@avik-pal avik-pal merged commit 52def24 into main Apr 19, 2026
117 of 126 checks passed
@avik-pal avik-pal deleted the mg/cuda-6 branch April 19, 2026 15:16