|
221 | 221 | function cloud_tpu_init!() |
222 | 222 | libtpu_dir = get_libtpu_dir() |
223 | 223 | num_tpu_chips, tpu_version = num_available_tpu_chips_and_device_id() |
| 224 | + if num_tpu_chips == 0 |
| 225 | + ENV["TPU_SKIP_MDS_QUERY"] = "1" |
| 226 | + end |
| 227 | + |
224 | 228 | if ( |
225 | 229 | tpu_version != TPUVersion.Unknown && |
226 | 230 | tpu_version ≥ TPUVersion.v5e && |
@@ -265,14 +269,14 @@ end |
265 | 269 |
|
266 | 270 | const _TPU_METADATA_RESPONSE_CODE_SUCCESS = 200 |
267 | 271 |
|
"""
    skip_mds_query() -> Bool

Report whether the `TPU_SKIP_MDS_QUERY` environment variable requests that
metadata queries be skipped.

Returns `false` when the variable is unset. A value that cannot be parsed as a
`Bool` (for example an empty string or `"yes"`) is treated as `false` and
reported once through `@warn`, rather than raising an `ArgumentError`.
"""
function skip_mds_query()
    # Single ENV lookup; `nothing` marks "variable not set".
    raw = get(ENV, "TPU_SKIP_MDS_QUERY", nothing)
    raw === nothing && return false

    # `tryparse` yields `nothing` instead of throwing on malformed input.
    flag = tryparse(Bool, raw)
    if flag === nothing
        @warn "Ignoring unparseable TPU_SKIP_MDS_QUERY value; expected a Bool" value = raw maxlog = 1
        return false
    end
    return flag
end
| 275 | + |
268 | 276 | function get_metadata(key) |
269 | 277 | # Based on https://github.com/tensorflow/tensorflow/pull/40317 |
270 | 278 | gce_metadata_endpoint = |
271 | | - "http://" * get( |
272 | | - ENV, |
273 | | - "GCE_METADATA_IP", |
274 | | - get(ENV, "GCE_METADATA_HOST", "metadata.google.internal"), |
275 | | - ) |
| 279 | + "http://" * get(ENV, "GCE_METADATA_IP", "metadata.google.internal") |
276 | 280 | @debug "Getting metadata for key: $(key)" gce_metadata_endpoint |
277 | 281 | retry_count = 0 |
278 | 282 | retry_seconds = parse(Float64, get(ENV, "REACTANT_GCE_METADATA_RETRY_SECONDS", "0.5")) |
|
308 | 312 | function get_tpu_env_value(key) |
309 | 313 | haskey(ENV, key) && return ENV[key] |
310 | 314 |
|
| 315 | + skip_mds_query() && return nothing |
| 316 | + |
311 | 317 | tpu_env_data = first(get_metadata("tpu-env")) |
312 | 318 | key_value_pairs = split(tpu_env_data, "\n") |
313 | 319 | for key_value_pair in key_value_pairs |
|
0 commit comments