77import importlib .util
88import inspect
99import platform
10+ import sys
1011from pathlib import Path
1112from typing import Any , Callable
1213
1314import pytest
1415
1516import torch
17+
18+ from _utils_internal import is_npu_available
1619from packaging import version
1720
1821from tensordict import (
5053
5154_IS_OSX = platform .system () == "Darwin"
5255
56+ npu_device_count = 0
57+ if torch .cuda .is_available ():
58+ cur_device = "cuda"
59+ elif is_npu_available ():
60+ cur_device = "npu"
61+ npu_device_count = torch .npu .device_count ()
62+
5363
64+ @pytest .mark .skipif (
65+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
66+ )
5467def test_vmap_compile ():
5568 # Since we monkey patch vmap we need to make sure compile is happy with it
5669 def func (x , y ):
@@ -67,6 +80,9 @@ def func(x, y):
6780@pytest .mark .skipif (
6881 TORCH_VERSION < version .parse ("2.4.0" ), reason = "requires torch>=2.4"
6982)
83+ @pytest .mark .skipif (
84+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
85+ )
7086@pytest .mark .parametrize ("mode" , [None , "reduce-overhead" ])
7187class TestTD :
7288 def test_tensor_output (self , mode ):
@@ -266,7 +282,7 @@ def make_td_with_names(data):
266282 )
267283 @pytest .mark .parametrize ("has_device" , [True , False ])
268284 def test_to (self , has_device , mode ):
269- device = "cuda :0"
285+ device = f" { cur_device } :0"
270286
271287 def test_to_device (td ):
272288 return td .to (device )
@@ -283,6 +299,10 @@ def test_to_device(td):
283299 assert td_device_c .batch_size == td .batch_size
284300 assert td_device_c .device == torch .device (device )
285301
302+ @pytest .mark .skipif (
303+ is_npu_available (),
304+ reason = "torch.device in torch.compile is not supported on NPU currently." ,
305+ )
286306 def test_lock (self , mode ):
287307 def locked_op (td ):
288308 # Adding stuff uses cache, check that this doesn't break
@@ -357,6 +377,9 @@ class MyClass:
357377@pytest .mark .skipif (
358378 TORCH_VERSION < version .parse ("2.4.0" ), reason = "requires torch>=2.4"
359379)
380+ @pytest .mark .skipif (
381+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
382+ )
360383@pytest .mark .parametrize ("mode" , [None , "reduce-overhead" ])
361384class TestTC :
362385 def test_tc_tensor_output (self , mode ):
@@ -553,7 +576,7 @@ def clone(td: TensorDict):
553576 )
554577 @pytest .mark .parametrize ("has_device" , [True , False ])
555578 def test_tc_to (self , has_device , mode ):
556- device = "cuda :0"
579+ device = f" { cur_device } :0"
557580
558581 def test_to_device (tc ):
559582 return tc .to (device )
@@ -570,6 +593,10 @@ def test_to_device(tc):
570593 assert tc_device_c .batch_size == data .batch_size
571594 assert tc_device_c .device == torch .device (device )
572595
596+ @pytest .mark .skipif (
597+ is_npu_available (),
598+ reason = "torch.device in torch.compile is not supported on NPU currently." ,
599+ )
573600 def test_tc_lock (self , mode ):
574601 def locked_op (tc ):
575602 # Adding stuff uses cache, check that this doesn't break
@@ -621,6 +648,9 @@ def func_c_mytd():
621648@pytest .mark .skipif (
622649 TORCH_VERSION < version .parse ("2.4.0" ), reason = "requires torch>=2.4"
623650)
651+ @pytest .mark .skipif (
652+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
653+ )
624654@pytest .mark .parametrize ("mode" , [None , "reduce-overhead" ])
625655class TestNN :
626656 def test_func (self , mode ):
@@ -725,6 +755,9 @@ def test_prob_module_with_kwargs(self, mode):
725755@pytest .mark .skipif (
726756 TORCH_VERSION <= version .parse ("2.4.0" ), reason = "requires torch>2.4"
727757)
758+ @pytest .mark .skipif (
759+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
760+ )
728761@pytest .mark .parametrize ("mode" , [None , "reduce-overhead" ])
729762class TestFunctional :
730763 def test_functional_error (self , mode ):
@@ -1015,6 +1048,9 @@ def to_numpy(tensor):
10151048 (TORCH_VERSION <= version .parse ("2.7.0" )) and _IS_OSX ,
10161049 reason = "requires torch>=2.7 ons OSX" ,
10171050)
1051+ @pytest .mark .skipif (
1052+ sys .version_info > (3 , 14 ), reason = "torch.compile is not supported on python 3.14+ "
1053+ )
10181054@pytest .mark .parametrize ("compiled" , [False , True ])
10191055class TestCudaGraphs :
10201056 @pytest .fixture (scope = "class" , autouse = True )
@@ -1239,7 +1275,7 @@ class TestCompileNontensor:
12391275 # Same issue with the decorator @tensorclass version
12401276 @pytest .fixture (scope = "class" )
12411277 def data (self ):
1242- return torch .zeros ((4 , 3 ), device = "cuda" )
1278+ return torch .zeros ((4 , 3 ), device = cur_device )
12431279
12441280 class TensorClassWithNonTensorData (TensorClass ["nocast" ]):
12451281 tensor : torch .Tensor
@@ -1257,13 +1293,13 @@ def fn_no_device(self, data):
12571293
12581294 def fn_with_device (self , data ):
12591295 a = self .TensorClassWithNonTensorData (
1260- tensor = data , non_tensor_data = 1 , batch_size = [4 ], device = "cuda"
1296+ tensor = data , non_tensor_data = 1 , batch_size = [4 ], device = cur_device
12611297 )
12621298 return a .tensor
12631299
12641300 def fn_with_device_without_batch_size (self , data ):
12651301 a = self .TensorClassWithNonTensorData (
1266- tensor = data , non_tensor_data = 1 , device = "cuda"
1302+ tensor = data , non_tensor_data = 1 , device = cur_device
12671303 )
12681304 return a .tensor
12691305
0 commit comments