Skip to content

Commit 2017a2a

Browse files
committed
fix: support TPU_SKIP_MDS_QUERY
1 parent 3fcfae7 commit 2017a2a

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

src/Distributed.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ function is_env_present(::GceTPUCluster)
381381
return false
382382
end
383383

384-
if haskey(ENV, "TPU_SKIP_MDS_QUERY")
384+
if Accelerators.TPU.skip_mds_query()
385385
@debug "TPU_SKIP_MDS_QUERY is set to True, so it's probably not a GCE TPU cluster."
386386
return false
387387
end

src/accelerators/TPU.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ end
221221
function cloud_tpu_init!()
222222
libtpu_dir = get_libtpu_dir()
223223
num_tpu_chips, tpu_version = num_available_tpu_chips_and_device_id()
224+
if num_tpu_chips == 0
225+
ENV["TPU_SKIP_MDS_QUERY"] = "1"
226+
end
227+
224228
if (
225229
tpu_version != TPUVersion.Unknown &&
226230
tpu_version TPUVersion.v5e &&
@@ -265,14 +269,14 @@ end
265269

266270
const _TPU_METADATA_RESPONSE_CODE_SUCCESS = 200
267271

272+
function skip_mds_query()
273+
return haskey(ENV, "TPU_SKIP_MDS_QUERY") && parse(Bool, ENV["TPU_SKIP_MDS_QUERY"])
274+
end
275+
268276
function get_metadata(key)
269277
# Based on https://github.com/tensorflow/tensorflow/pull/40317
270278
gce_metadata_endpoint =
271-
"http://" * get(
272-
ENV,
273-
"GCE_METADATA_IP",
274-
get(ENV, "GCE_METADATA_HOST", "metadata.google.internal"),
275-
)
279+
"http://" * get(ENV, "GCE_METADATA_IP", "metadata.google.internal")
276280
@debug "Getting metadata for key: $(key)" gce_metadata_endpoint
277281
retry_count = 0
278282
retry_seconds = parse(Float64, get(ENV, "REACTANT_GCE_METADATA_RETRY_SECONDS", "0.5"))
@@ -308,6 +312,8 @@ end
308312
function get_tpu_env_value(key)
309313
haskey(ENV, key) && return ENV[key]
310314

315+
skip_mds_query() && return nothing
316+
311317
tpu_env_data = first(get_metadata("tpu-env"))
312318
key_value_pairs = split(tpu_env_data, "\n")
313319
for key_value_pair in key_value_pairs

0 commit comments

Comments
 (0)