22
export @cuda, cudaconvert, cufunction, dynamic_cufunction, nextwarp, prevwarp
@public maxthreads, registers, memory, version, KernelAdaptor
@public AbstractBackend, LLVMBackend, DefaultBackend, kernel_convert, kernel_compile
6+
7+
## backend dispatch
9+
"""
    AbstractBackend

Abstract supertype used by `@cuda` to dispatch between compiler backends.
The default is [`LLVMBackend`](@ref), which compiles SIMT/PTX kernels via
[`cufunction`](@ref). Alternative backends (e.g. Tile IR via cuTile.jl)
declare a subtype and implement [`kernel_convert`](@ref) and
[`kernel_compile`](@ref) for it; `@cuda backend=...` then dispatches to
those methods.

The `backend=` keyword accepts either an `AbstractBackend` instance or a
module that defines a `DefaultBackend()` function returning one (so
`@cuda backend=cuTile ...` resolves to `cuTile.DefaultBackend()`).
"""
abstract type AbstractBackend end
24+
"""
    LLVMBackend()

The default backend for `@cuda`: kernels are compiled to SIMT/PTX code with
[`cufunction`](@ref), and launch arguments are converted with
[`cudaconvert`](@ref).
"""
struct LLVMBackend <: AbstractBackend end
32+
"""
    DefaultBackend()

Return the default `@cuda` backend for this module, i.e. [`LLVMBackend`](@ref).

Defining this makes `@cuda backend=CUDA ...` (or `backend=CUDACore`) resolve
to [`LLVMBackend`](@ref), following the same convention used by other backend
packages (e.g. `@cuda backend=cuTile ...` resolves to
`cuTile.DefaultBackend()`).
"""
DefaultBackend() = LLVMBackend()
42+
"""
    kernel_convert(backend, x)

Convert a host-side launch argument `x` into the form the kernel will see.
For [`LLVMBackend`](@ref) this simply forwards to [`cudaconvert`](@ref);
other backends add methods that produce their own argument representations.
"""
kernel_convert(::LLVMBackend, x) = cudaconvert(x)
51+
"""
    kernel_compile(backend, f, tt::Type{<:Tuple}; kwargs...) -> AbstractKernel

Compile `f` for the given backend and argument-type tuple `tt`. The result is
an [`AbstractKernel`](@ref) that can be called as
`kernel(args...; launch_kwargs...)` to launch on the GPU. For
[`LLVMBackend`](@ref) this forwards to [`cufunction`](@ref).
"""
kernel_compile(::LLVMBackend, f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} =
    cufunction(f, tt; kwargs...)
561
662
## high-level @cuda interface
864
# Keyword arguments that `@cuda` itself understands, grouped by where they
# apply: macro behavior, kernel compilation, and kernel launch. Anything else
# is treated as a backend-specific compiler keyword.
const MACRO_KWARGS = [:dynamic, :launch, :backend]
const COMPILER_KWARGS = [:kernel, :name, :always_inline, :minthreads, :maxthreads,
                         :blocks_per_sm, :maxregs, :fastmath, :cap, :ptx]
const LAUNCH_KWARGS = [:cooperative, :blocks, :threads, :clustersize, :shmem, :stream]
1268
@@ -24,6 +80,10 @@ Several keyword arguments are supported that influence the behavior of `@cuda`.
2480- `launch`: whether to launch this kernel, defaults to `true`. If `false` the returned
2581 kernel object should be launched by calling it and passing arguments again.
2682- `dynamic`: use dynamic parallelism to launch device-side kernels, defaults to `false`.
83+ - `backend`: which compiler backend to use, defaults to [`LLVMBackend`](@ref). Either an
84+ [`AbstractBackend`](@ref) instance or a module that defines `DefaultBackend()` (e.g.
85+ `backend=CUDA` resolves to `CUDA.DefaultBackend()`). Backend-specific compiler kwargs
86+ not recognized by `@cuda` itself are forwarded to [`kernel_compile`](@ref).
2787- arguments that influence kernel compilation: see [`cufunction`](@ref) and
2888 [`dynamic_cufunction`](@ref)
2989- arguments that influence kernel launch: see [`CUDACore.HostKernel`](@ref) and
@@ -50,17 +110,16 @@ macro cuda(ex...)
50110 code = quote end
51111 vars, var_exprs = assign_args! (code, args)
52112
53- # group keyword argument
113+ # group keyword argument. Backend-specific compiler kwargs land in
114+ # `other_kwargs` and are forwarded to `kernel_compile`; the backend
115+ # validates them.
54116 macro_kwargs, compiler_kwargs, call_kwargs, other_kwargs =
55117 split_kwargs (kwargs, MACRO_KWARGS, COMPILER_KWARGS, LAUNCH_KWARGS)
56- if ! isempty (other_kwargs)
57- key,val = first (other_kwargs). args
58- throw (ArgumentError (" Unsupported keyword argument '$key '" ))
59- end
60118
61119 # handle keyword arguments that influence the macro's behavior
62120 dynamic = false
63121 launch = true
122+ backend_expr = :($ LLVMBackend ())
64123 for kwarg in macro_kwargs
65124 key:: Symbol , val = kwarg. args
66125 if key === :dynamic
@@ -69,6 +128,8 @@ macro cuda(ex...)
69128 elseif key === :launch
70129 isa (val, Bool) || throw (ArgumentError (" `launch` keyword argument to @cuda should be a constant value" ))
71130 launch = val:: Bool
131+ elseif key === :backend
132+ backend_expr = val
72133 else
73134 throw (ArgumentError (" Unsupported keyword argument '$key '" ))
74135 end
@@ -79,12 +140,14 @@ macro cuda(ex...)
79140
80141 # FIXME : macro hygiene wrt. escaping kwarg values (this broke with 1.5)
81142 # we esc() the whole thing now, necessitating gensyms...
82- @gensym f_var kernel_f kernel_args kernel_tt kernel
143+ @gensym f_var kernel_f kernel_args kernel_tt kernel backend backend_raw
83144 if dynamic
84145 # FIXME : we could probably somehow support kwargs with constant values by either
85146 # saving them in a global Dict here, or trying to pick them up from the Julia
86147 # IR when processing the dynamic parallelism marker
87148 isempty (compiler_kwargs) || error (" @cuda dynamic parallelism does not support compiler keyword arguments" )
149+ isempty (other_kwargs) ||
150+ error (" @cuda dynamic parallelism does not support backend-specific compiler keyword arguments" )
88151
89152 # dynamic, device-side kernel launch
90153 push! (code. args,
@@ -105,12 +168,19 @@ macro cuda(ex...)
105168 # while keeping the original arguments alive
106169 push! (code. args,
107170 quote
171+ # Accept either an `AbstractBackend` instance or a module
172+ # providing `DefaultBackend()` (e.g. `backend=cuTile`).
173+ # Inference folds the branch away on concretely-typed inputs.
174+ $ backend = let $ backend_raw = $ backend_expr
175+ $ backend_raw isa $ AbstractBackend ? $ backend_raw : $ backend_raw. DefaultBackend ()
176+ end
108177 $ f_var = $ f
109178 GC. @preserve $ (vars... ) $ f_var begin
110- $ kernel_f = $ cudaconvert ( $ f_var)
111- $ kernel_args = map ($ cudaconvert , ($ (var_exprs... ),))
179+ $ kernel_f = $ kernel_convert ( $ backend, $ f_var)
180+ $ kernel_args = map (x -> $ kernel_convert ( $ backend, x) , ($ (var_exprs... ),))
112181 $ kernel_tt = Tuple{map (Core. Typeof, $ kernel_args)... }
113- $ kernel = $ cufunction ($ kernel_f, $ kernel_tt; $ (compiler_kwargs... ))
182+ $ kernel = $ kernel_compile ($ backend, $ kernel_f, $ kernel_tt;
183+ $ (compiler_kwargs... ), $ (other_kwargs... ))
114184 if $ launch
115185 $ kernel ($ kernel_args... ; $ (call_kwargs... ), convert= Val (false ))
116186 end
@@ -239,10 +309,12 @@ The following keyword arguments are supported:
239309AbstractKernel
240310
function Base.show(io::IO, k::AbstractKernel{F,TT}) where {F,TT}
    # Print the fully-qualified type so kernel subtypes defined by other
    # backend packages display under their own module rather than CUDACore.
    kt = typeof(k)
    print(io, "$(parentmodule(kt)).$(nameof(kt))($(k.f))")
end
function Base.show(io::IO, ::MIME"text/plain", k::AbstractKernel{F,TT}) where {F,TT}
    # Multi-line REPL display: qualified kernel type, the wrapped function,
    # and its compiled argument types.
    kt = typeof(k)
    print(io, "$(parentmodule(kt)).$(nameof(kt)) for $(k.f)($(join(TT.parameters, ", ")))")
end
247319
248320@inline @generated function (kernel:: AbstractKernel{F,TT} )(args:: Vararg{Any,N} ;
0 commit comments