Skip to content

Commit fbb38fa

Browse files
authored
fix: store list of attempted initializations (#2822)
1 parent 20cbf71 commit fbb38fa

1 file changed

Lines changed: 41 additions & 11 deletions

File tree

src/xla/XLA.jl

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,17 @@ for runtime in (:PJRT, :IFRT)
6262
initialized::Bool
6363
clients::Dict{String,$(runtime).Client}
6464
default_client::$(runtime).Client
65+
attempted_backends::Set{String}
6566

6667
function $(backend_state)(
6768
initialized::Bool=false,
6869
clients::Dict{String,$(runtime).Client}=Dict{String,$(runtime).Client}(),
6970
default_client::$(runtime).Client=$(runtime).NullClient,
71+
attempted_backends::Set{String}=Set{String}(),
7072
)
7173
return finalizer(
72-
finalize_backend_state, new(initialized, clients, default_client)
74+
finalize_backend_state,
75+
new(initialized, clients, default_client, attempted_backends),
7376
)
7477
end
7578
end
@@ -82,6 +85,7 @@ for runtime in (:PJRT, :IFRT)
8285
free_client(client)
8386
end
8487
empty!(state.clients)
88+
empty!(state.attempted_backends)
8589
state.default_client = $(runtime).NullClient
8690
return nothing
8791
end
@@ -126,15 +130,34 @@ end
126130

127131
function client(backend::String)
128132
if backend == "gpu"
129-
if haskey(global_backend_state.clients, "cuda")
130-
backend = "cuda"
131-
elseif haskey(global_backend_state.clients, "metal")
132-
backend = "metal"
133-
elseif haskey(global_backend_state.clients, "rocm")
134-
backend = "rocm"
135-
else
136-
error("No GPU client found")
133+
for b in ["cuda", "metal", "rocm"]
134+
if haskey(global_backend_state.clients, b)
135+
return global_backend_state.clients[b]
136+
end
137+
end
138+
139+
# If none found, check if we've attempted all of them
140+
gpu_backends = normalize_backend_name("gpu")
141+
if any(b ∉ global_backend_state.attempted_backends for b in gpu_backends)
142+
with(BACKENDS_TO_INITIALIZE => gpu_backends) do
143+
initialize_default_clients!(global_backend_state)
144+
end
145+
union!(global_backend_state.attempted_backends, gpu_backends)
146+
return client("gpu") # Try again
147+
end
148+
149+
error("No GPU client found")
150+
end
151+
152+
if (
153+
!haskey(global_backend_state.clients, backend) &&
154+
backend ∉ global_backend_state.attempted_backends
155+
)
156+
backends = normalize_backend_name(backend)
157+
with(BACKENDS_TO_INITIALIZE => backends) do
158+
initialize_default_clients!(global_backend_state)
137159
end
160+
union!(global_backend_state.attempted_backends, backends)
138161
end
139162
return global_backend_state.clients[backend]
140163
end
@@ -264,8 +287,15 @@ for runtime in (:PJRT, :IFRT)
264287
was_initialized;
265288
allow_initialization=backend -> begin
266289
Reactant.precompiling() && return backend.platform_name == "cpu"
267-
backends_to_initialize === missing && return true
268-
return backend.platform_name in backends_to_initialize
290+
will_try_initializing = false
291+
if backends_to_initialize === missing
292+
will_try_initializing = true
293+
else
294+
will_try_initializing = backend.platform_name in backends_to_initialize
295+
end
296+
will_try_initializing &&
297+
push!(state.attempted_backends, backend.platform_name)
298+
return will_try_initializing
269299
end,
270300
node_id=global_state.process_id,
271301
num_nodes=global_state.num_processes,

0 commit comments

Comments
 (0)