@@ -62,14 +62,17 @@ for runtime in (:PJRT, :IFRT)
6262 initialized:: Bool
6363 clients:: Dict{String,$(runtime).Client}
6464 default_client:: $ (runtime). Client
65+ attempted_backends:: Set{String}
6566
6667 function $ (backend_state)(
6768 initialized:: Bool = false ,
6869 clients:: Dict{String,$(runtime).Client} = Dict {String,$(runtime).Client} (),
6970 default_client:: $ (runtime). Client= $ (runtime). NullClient,
71+ attempted_backends:: Set{String} = Set {String} (),
7072 )
7173 return finalizer (
72- finalize_backend_state, new (initialized, clients, default_client)
74+ finalize_backend_state,
75+ new (initialized, clients, default_client, attempted_backends),
7376 )
7477 end
7578 end
@@ -82,6 +85,7 @@ for runtime in (:PJRT, :IFRT)
8285 free_client (client)
8386 end
8487 empty! (state. clients)
88+ empty! (state. attempted_backends)
8589 state. default_client = $ (runtime). NullClient
8690 return nothing
8791 end
@@ -126,15 +130,34 @@ end
126130
127131function client (backend:: String )
128132 if backend == " gpu"
129- if haskey (global_backend_state. clients, " cuda" )
130- backend = " cuda"
131- elseif haskey (global_backend_state. clients, " metal" )
132- backend = " metal"
133- elseif haskey (global_backend_state. clients, " rocm" )
134- backend = " rocm"
135- else
136- error (" No GPU client found" )
133+ for b in [" cuda" , " metal" , " rocm" ]
134+ if haskey (global_backend_state. clients, b)
135+ return global_backend_state. clients[b]
136+ end
137+ end
138+
139+ # If none found, check if we've attempted all of them
140+ gpu_backends = normalize_backend_name (" gpu" )
141+ if any (b ∉ global_backend_state. attempted_backends for b in gpu_backends)
142+ with (BACKENDS_TO_INITIALIZE => gpu_backends) do
143+ initialize_default_clients! (global_backend_state)
144+ end
145+ union! (global_backend_state. attempted_backends, gpu_backends)
146+ return client (" gpu" ) # Try again
147+ end
148+
149+ error (" No GPU client found" )
150+ end
151+
152+ if (
153+ ! haskey (global_backend_state. clients, backend) &&
154+ backend ∉ global_backend_state. attempted_backends
155+ )
156+ backends = normalize_backend_name (backend)
157+ with (BACKENDS_TO_INITIALIZE => backends) do
158+ initialize_default_clients! (global_backend_state)
137159 end
160+ union! (global_backend_state. attempted_backends, backends)
138161 end
139162 return global_backend_state. clients[backend]
140163end
@@ -264,8 +287,15 @@ for runtime in (:PJRT, :IFRT)
264287 was_initialized;
265288 allow_initialization= backend -> begin
266289 Reactant. precompiling () && return backend. platform_name == " cpu"
267- backends_to_initialize === missing && return true
268- return backend. platform_name in backends_to_initialize
290+ will_try_initializing = false
291+ if backends_to_initialize === missing
292+ will_try_initializing = true
293+ else
294+ will_try_initializing = backend. platform_name in backends_to_initialize
295+ end
296+ will_try_initializing &&
297+ push! (state. attempted_backends, backend. platform_name)
298+ return will_try_initializing
269299 end ,
270300 node_id= global_state. process_id,
271301 num_nodes= global_state. num_processes,
0 commit comments