Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@ for runtime in (:PJRT, :IFRT)
initialized::Bool
clients::Dict{String,$(runtime).Client}
default_client::$(runtime).Client
attempted_backends::Set{String}

function $(backend_state)(
initialized::Bool=false,
clients::Dict{String,$(runtime).Client}=Dict{String,$(runtime).Client}(),
default_client::$(runtime).Client=$(runtime).NullClient,
attempted_backends::Set{String}=Set{String}(),
)
return finalizer(
finalize_backend_state, new(initialized, clients, default_client)
finalize_backend_state,
new(initialized, clients, default_client, attempted_backends),
)
end
end
Expand All @@ -82,6 +85,7 @@ for runtime in (:PJRT, :IFRT)
free_client(client)
end
empty!(state.clients)
empty!(state.attempted_backends)
state.default_client = $(runtime).NullClient
return nothing
end
Expand Down Expand Up @@ -126,15 +130,34 @@ end

function client(backend::String)
if backend == "gpu"
if haskey(global_backend_state.clients, "cuda")
backend = "cuda"
elseif haskey(global_backend_state.clients, "metal")
backend = "metal"
elseif haskey(global_backend_state.clients, "rocm")
backend = "rocm"
else
error("No GPU client found")
for b in ["cuda", "metal", "rocm"]
if haskey(global_backend_state.clients, b)
return global_backend_state.clients[b]
end
end

# If none found, check if we've attempted all of them
gpu_backends = normalize_backend_name("gpu")
if any(b ∉ global_backend_state.attempted_backends for b in gpu_backends)
with(BACKENDS_TO_INITIALIZE => gpu_backends) do
initialize_default_clients!(global_backend_state)
end
union!(global_backend_state.attempted_backends, gpu_backends)
return client("gpu") # Try again
end

error("No GPU client found")
end

if (
!haskey(global_backend_state.clients, backend) &&
backend ∉ global_backend_state.attempted_backends
)
backends = normalize_backend_name(backend)
with(BACKENDS_TO_INITIALIZE => backends) do
initialize_default_clients!(global_backend_state)
end
union!(global_backend_state.attempted_backends, backends)
end
return global_backend_state.clients[backend]
end
Expand Down Expand Up @@ -264,8 +287,15 @@ for runtime in (:PJRT, :IFRT)
was_initialized;
allow_initialization=backend -> begin
Reactant.precompiling() && return backend.platform_name == "cpu"
backends_to_initialize === missing && return true
return backend.platform_name in backends_to_initialize
will_try_initializing = false
if backends_to_initialize === missing
will_try_initializing = true
else
will_try_initializing = backend.platform_name in backends_to_initialize
end
will_try_initializing &&
push!(state.attempted_backends, backend.platform_name)
return will_try_initializing
end,
node_id=global_state.process_id,
num_nodes=global_state.num_processes,
Expand Down
Loading