@@ -8,6 +8,7 @@ using EnumX: @enumx
88using Enzyme: Compiler
99using Preferences: load_preference
1010using UUIDs: UUID
11+ using ScopedValues: ScopedValue, with
1112
1213using Setfield: Setfield, @set!
1314
@@ -48,6 +49,8 @@ include("IFRT/IFRT.jl")
4849
4950include (" CompileOptions.jl" )
5051
52+ const BACKENDS_TO_INITIALIZE = ScopedValue {Union{Missing,Set{String}}} (missing )
53+
5154abstract type AbstractBackendState end
5255
5356function finalize_backend_state end
@@ -116,12 +119,19 @@ function cleanup_backend_state()
116119 return nothing
117120end
118121
122+ function normalize_backend_name (backend:: String )
123+ backend == " gpu" && return Set ([" cuda" , " metal" , " rocm" ])
124+ return Set ([backend])
125+ end
126+
119127function client (backend:: String )
120128 if backend == " gpu"
121129 if haskey (global_backend_state. clients, " cuda" )
122130 backend = " cuda"
123131 elseif haskey (global_backend_state. clients, " metal" )
124132 backend = " metal"
133+ elseif haskey (global_backend_state. clients, " rocm" )
134+ backend = " rocm"
125135 else
126136 error (" No GPU client found" )
127137 end
@@ -145,7 +155,9 @@ function set_default_backend(backend::AbstractClient)
145155end
146156
147157function set_default_backend (backend:: String )
148- global_backend_state. default_client = client (backend)
158+ with (BACKENDS_TO_INITIALIZE => normalize_backend_name (backend)) do
159+ global_backend_state. default_client = client (backend)
160+ end
149161 return nothing
150162end
151163
235247
236248for runtime in (:PJRT , :IFRT )
237249 @eval function initialize_default_clients! (state:: $ (Symbol (runtime, :BackendState )))
250+ backends_to_initialize = BACKENDS_TO_INITIALIZE[]
251+
238252 was_initialized = state. initialized
239253 state. initialized = true
240254 distributed_runtime_client = if global_state. num_processes > 1
@@ -249,10 +263,9 @@ for runtime in (:PJRT, :IFRT)
249263 state,
250264 was_initialized;
251265 allow_initialization= backend -> begin
252- if Reactant. precompiling ()
253- return backend. platform_name == " cpu"
254- end
255- return true
266+ Reactant. precompiling () && return backend. platform_name == " cpu"
267+ backends_to_initialize === missing && return true
268+ return backend. platform_name in backends_to_initialize
256269 end ,
257270 node_id= global_state. process_id,
258271 num_nodes= global_state. num_processes,
0 commit comments