Skip to content

Commit 010db35

Browse files
authored
fix: dont initialize all backends if specific backend is requested (#2778)
1 parent 9d0aca7 commit 010db35

1 file changed

Lines changed: 18 additions & 5 deletions

File tree

src/xla/XLA.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using EnumX: @enumx
88
using Enzyme: Compiler
99
using Preferences: load_preference
1010
using UUIDs: UUID
11+
using ScopedValues: ScopedValue, with
1112

1213
using Setfield: Setfield, @set!
1314

@@ -48,6 +49,8 @@ include("IFRT/IFRT.jl")
4849

4950
include("CompileOptions.jl")
5051

52+
const BACKENDS_TO_INITIALIZE = ScopedValue{Union{Missing,Set{String}}}(missing)
53+
5154
abstract type AbstractBackendState end
5255

5356
function finalize_backend_state end
@@ -116,12 +119,19 @@ function cleanup_backend_state()
116119
return nothing
117120
end
118121

122+
function normalize_backend_name(backend::String)
123+
backend == "gpu" && return Set(["cuda", "metal", "rocm"])
124+
return Set([backend])
125+
end
126+
119127
function client(backend::String)
120128
if backend == "gpu"
121129
if haskey(global_backend_state.clients, "cuda")
122130
backend = "cuda"
123131
elseif haskey(global_backend_state.clients, "metal")
124132
backend = "metal"
133+
elseif haskey(global_backend_state.clients, "rocm")
134+
backend = "rocm"
125135
else
126136
error("No GPU client found")
127137
end
@@ -145,7 +155,9 @@ function set_default_backend(backend::AbstractClient)
145155
end
146156

147157
function set_default_backend(backend::String)
148-
global_backend_state.default_client = client(backend)
158+
with(BACKENDS_TO_INITIALIZE => normalize_backend_name(backend)) do
159+
global_backend_state.default_client = client(backend)
160+
end
149161
return nothing
150162
end
151163

@@ -235,6 +247,8 @@ end
235247

236248
for runtime in (:PJRT, :IFRT)
237249
@eval function initialize_default_clients!(state::$(Symbol(runtime, :BackendState)))
250+
backends_to_initialize = BACKENDS_TO_INITIALIZE[]
251+
238252
was_initialized = state.initialized
239253
state.initialized = true
240254
distributed_runtime_client = if global_state.num_processes > 1
@@ -249,10 +263,9 @@ for runtime in (:PJRT, :IFRT)
249263
state,
250264
was_initialized;
251265
allow_initialization=backend -> begin
252-
if Reactant.precompiling()
253-
return backend.platform_name == "cpu"
254-
end
255-
return true
266+
Reactant.precompiling() && return backend.platform_name == "cpu"
267+
backends_to_initialize === missing && return true
268+
return backend.platform_name in backends_to_initialize
256269
end,
257270
node_id=global_state.process_id,
258271
num_nodes=global_state.num_processes,

0 commit comments

Comments
 (0)