Skip to content

Commit 321db55

Browse files
committed
XPU support for float8 dtensor test
1 parent 5cc2ef9 commit 321db55

2 files changed

Lines changed: 18 additions & 5 deletions

File tree

test/float8/test_dtensor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,20 @@
4545

4646
torch.set_float32_matmul_precision("high")
4747

48+
_GPU_DEVICE = (
49+
str(torch.accelerator.current_accelerator())
50+
if torch.accelerator.is_available()
51+
else "no_gpu"
52+
)
53+
4854

4955
def setup_distributed():
    """Build a 1-D device mesh spanning WORLD_SIZE ranks on the detected accelerator.

    Returns:
        The initialized ``DeviceMesh`` for this process group.
    """
    num_ranks = int(os.environ.get("WORLD_SIZE", -1))
    mesh = init_device_mesh(_GPU_DEVICE, (num_ranks,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    rank = torch.distributed.get_rank()
    # Bind this rank to its own device on the backend named by _GPU_DEVICE.
    torch.get_device_module(_GPU_DEVICE).set_device(rank)
    return mesh
5763

5864

@@ -213,7 +219,7 @@ def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
213219

214220
def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
215221
torch.manual_seed(42)
216-
model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
222+
model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).to(_GPU_DEVICE)
217223
convert_to_float8_training(
218224
model,
219225
config=Float8LinearConfig(

test/float8/test_dtensor.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
# terminate script on first error
set -e

# Skip the whole test when neither an XPU nor a CUDA device is present;
# exit status of the inline python probe drives the decision.
if ! python -c 'import sys, torch; sys.exit(0 if ((hasattr(torch, "xpu") and torch.xpu.is_available()) or torch.cuda.is_available()) else 1)'
then
echo "Skipping test_dtensor.sh because no XPU/CUDA devices are available."
exit
fi
1522

0 commit comments

Comments
 (0)