Skip to content

Commit 8bd223e

Browse files
Chen666 and vmoens authored
[Feature] Add more supports for NPU in addition to CUDA in previously supported use cases. (#1471)
Co-authored-by: chenhao388 <chenhao388@huawei.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 5729045 commit 8bd223e

7 files changed

Lines changed: 133 additions & 36 deletions

File tree

tensordict/_td.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,10 +1802,7 @@ def split(
18021802
raise ValueError(
18031803
f"TensorDict.split: split_size must be positive, got {split_size}."
18041804
)
1805-
if split_size > max_size:
1806-
raise ValueError(
1807-
f"TensorDict.split: split_size ({split_size}) exceeds dimension size ({max_size})."
1808-
)
1805+
split_size = min(split_size, max_size)
18091806
segments = _create_segments_from_int(split_size, max_size)
18101807
splits = [end - start for start, end in segments]
18111808
splits = {k: v.split(splits, dim) for k, v in self.items()}

tensordict/nn/tensorclass_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from collections.abc import Iterable
55
from dataclasses import Field
6-
from typing import Any, cast, Generic, get_args, get_origin, TypeVar
6+
from typing import Any, cast, Generic, get_args, get_origin, TypeVar, Union
77

88
from tensordict._td import TensorDict
99
from tensordict.nn.common import dispatch, TensorDictModuleBase
@@ -101,8 +101,8 @@ def forward(self, tensordict: TensorDict, *args, **kwargs) -> TensorDict:
101101
).to_tensordict()
102102

103103

104-
InputClass = TypeVar("InputClass", bound=(TensorClass | Tensor))
105-
OutputClass = TypeVar("OutputClass", bound=(TensorClass | Tensor))
104+
InputClass = TypeVar("InputClass", bound=Union[TensorClass, Tensor])
105+
OutputClass = TypeVar("OutputClass", bound=Union[TensorClass, Tensor])
106106

107107

108108
class TensorClassModuleBase(Generic[InputClass, OutputClass], ABC, nn.Module):

test/test_compile.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
import importlib.util
88
import inspect
99
import platform
10+
import sys
1011
from pathlib import Path
1112
from typing import Any, Callable
1213

1314
import pytest
1415

1516
import torch
17+
18+
from _utils_internal import is_npu_available
1619
from packaging import version
1720

1821
from tensordict import (
@@ -50,7 +53,17 @@
5053

5154
_IS_OSX = platform.system() == "Darwin"
5255

56+
npu_device_count = 0
57+
if torch.cuda.is_available():
58+
cur_device = "cuda"
59+
elif is_npu_available():
60+
cur_device = "npu"
61+
npu_device_count = torch.npu.device_count()
62+
5363

64+
@pytest.mark.skipif(
65+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
66+
)
5467
def test_vmap_compile():
5568
# Since we monkey patch vmap we need to make sure compile is happy with it
5669
def func(x, y):
@@ -67,6 +80,9 @@ def func(x, y):
6780
@pytest.mark.skipif(
6881
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
6982
)
83+
@pytest.mark.skipif(
84+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
85+
)
7086
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
7187
class TestTD:
7288
def test_tensor_output(self, mode):
@@ -266,7 +282,7 @@ def make_td_with_names(data):
266282
)
267283
@pytest.mark.parametrize("has_device", [True, False])
268284
def test_to(self, has_device, mode):
269-
device = "cuda:0"
285+
device = f"{cur_device}:0"
270286

271287
def test_to_device(td):
272288
return td.to(device)
@@ -283,6 +299,10 @@ def test_to_device(td):
283299
assert td_device_c.batch_size == td.batch_size
284300
assert td_device_c.device == torch.device(device)
285301

302+
@pytest.mark.skipif(
303+
is_npu_available(),
304+
reason="torch.device in torch.compile is not supported on NPU currently.",
305+
)
286306
def test_lock(self, mode):
287307
def locked_op(td):
288308
# Adding stuff uses cache, check that this doesn't break
@@ -357,6 +377,9 @@ class MyClass:
357377
@pytest.mark.skipif(
358378
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
359379
)
380+
@pytest.mark.skipif(
381+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
382+
)
360383
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
361384
class TestTC:
362385
def test_tc_tensor_output(self, mode):
@@ -553,7 +576,7 @@ def clone(td: TensorDict):
553576
)
554577
@pytest.mark.parametrize("has_device", [True, False])
555578
def test_tc_to(self, has_device, mode):
556-
device = "cuda:0"
579+
device = f"{cur_device}:0"
557580

558581
def test_to_device(tc):
559582
return tc.to(device)
@@ -570,6 +593,10 @@ def test_to_device(tc):
570593
assert tc_device_c.batch_size == data.batch_size
571594
assert tc_device_c.device == torch.device(device)
572595

596+
@pytest.mark.skipif(
597+
is_npu_available(),
598+
reason="torch.device in torch.compile is not supported on NPU currently.",
599+
)
573600
def test_tc_lock(self, mode):
574601
def locked_op(tc):
575602
# Adding stuff uses cache, check that this doesn't break
@@ -621,6 +648,9 @@ def func_c_mytd():
621648
@pytest.mark.skipif(
622649
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
623650
)
651+
@pytest.mark.skipif(
652+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
653+
)
624654
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
625655
class TestNN:
626656
def test_func(self, mode):
@@ -725,6 +755,9 @@ def test_prob_module_with_kwargs(self, mode):
725755
@pytest.mark.skipif(
726756
TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
727757
)
758+
@pytest.mark.skipif(
759+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
760+
)
728761
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
729762
class TestFunctional:
730763
def test_functional_error(self, mode):
@@ -1015,6 +1048,9 @@ def to_numpy(tensor):
10151048
(TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
10161049
reason="requires torch>=2.7 ons OSX",
10171050
)
1051+
@pytest.mark.skipif(
1052+
sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+ "
1053+
)
10181054
@pytest.mark.parametrize("compiled", [False, True])
10191055
class TestCudaGraphs:
10201056
@pytest.fixture(scope="class", autouse=True)
@@ -1239,7 +1275,7 @@ class TestCompileNontensor:
12391275
# Same issue with the decorator @tensorclass version
12401276
@pytest.fixture(scope="class")
12411277
def data(self):
1242-
return torch.zeros((4, 3), device="cuda")
1278+
return torch.zeros((4, 3), device=cur_device)
12431279

12441280
class TensorClassWithNonTensorData(TensorClass["nocast"]):
12451281
tensor: torch.Tensor
@@ -1257,13 +1293,13 @@ def fn_no_device(self, data):
12571293

12581294
def fn_with_device(self, data):
12591295
a = self.TensorClassWithNonTensorData(
1260-
tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
1296+
tensor=data, non_tensor_data=1, batch_size=[4], device=cur_device
12611297
)
12621298
return a.tensor
12631299

12641300
def fn_with_device_without_batch_size(self, data):
12651301
a = self.TensorClassWithNonTensorData(
1266-
tensor=data, non_tensor_data=1, device="cuda"
1302+
tensor=data, non_tensor_data=1, device=cur_device
12671303
)
12681304
return a.tensor
12691305

test/test_distributed.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
import torch
1313
from _pytest.fixtures import fixture
14+
from _utils_internal import is_npu_available
1415
from packaging import version
1516

1617
from packaging.version import parse
@@ -107,6 +108,70 @@ def test_fsdp_module(self, tmpdir):
107108
assert (TensorDict.load_memmap(tmpdir) == 1).all()
108109

109110

111+
@pytest.mark.skipif(
112+
not is_npu_available() or not torch.npu.device_count() > 2,
113+
reason="not enough npu devices",
114+
)
115+
class TestNPUFSDP:
116+
class MyDModule(nn.Module):
117+
def __init__(self):
118+
super().__init__()
119+
self.fc1 = nn.Linear(8, 8, bias=False)
120+
self.fc2 = nn.Linear(8, 8, bias=False)
121+
self.relu = nn.ReLU()
122+
for p in self.parameters():
123+
p.data.fill_(1.0)
124+
125+
def forward(self, input):
126+
return self.relu(self.fc1(input) + self.fc2(input))
127+
128+
@classmethod
129+
def make_module(cls, device=None):
130+
with (
131+
torch.device(f"npu:{device}") if device is not None else torch.device("npu")
132+
):
133+
my_module = cls.MyDModule()
134+
my_sharded_module = FSDP(my_module, device_id=device)
135+
return my_sharded_module
136+
137+
@classmethod
138+
def worker(cls, rank, path):
139+
os.environ["MASTER_ADDR"] = "localhost"
140+
os.environ["MASTER_PORT"] = "10017"
141+
142+
torch.distributed.init_process_group(
143+
"hccl",
144+
rank=rank,
145+
world_size=2,
146+
init_method="tcp://localhost:10017",
147+
)
148+
torch.npu.set_device(rank)
149+
module = cls.make_module(rank)
150+
dist.barrier()
151+
# cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
152+
# with FSDP.state_dict_type(module, StateDictType.SHARDED_STATE_DICT): #, cfg):
153+
# tdlogger.info(module.state_dict())
154+
155+
# td = TensorDict(module.state_dict(), []).unflatten_keys(".")
156+
td = TensorDict.from_module(module, use_state_dict=True)
157+
if rank == 0:
158+
td.memmap(path)
159+
dist.destroy_process_group()
160+
161+
def test_fsdp_module(self, tmpdir):
162+
try:
163+
mp.set_start_method("spawn")
164+
except Exception:
165+
tdlogger.info("start method already set to", mp.get_start_method())
166+
proc0 = mp.Process(target=self.worker, args=(0, tmpdir))
167+
proc1 = mp.Process(target=self.worker, args=(1, tmpdir))
168+
proc0.start()
169+
proc1.start()
170+
proc0.join(timeout=TIMEOUT)
171+
proc1.join(timeout=TIMEOUT)
172+
assert (TensorDict.load_memmap(tmpdir) == 1).all()
173+
174+
110175
# not using TorchVersion to make the comparison work with dev
111176
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
112177

test/test_nn.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121
import torch
22+
from _utils_internal import is_npu_available
2223

2324
from tensordict import (
2425
is_tensor_collection,
@@ -116,6 +117,17 @@
116117
)
117118

118119

120+
def get_device():
121+
device = torch.device("cpu")
122+
if torch.cuda.is_available():
123+
device = torch.device("cuda:0")
124+
elif is_npu_available():
125+
device = torch.device("npu:0")
126+
elif torch.mps.is_available():
127+
device = torch.device("mps:0")
128+
return device
129+
130+
119131
class TestInteractionType:
120132
def test_base(self):
121133
with set_interaction_type("DETERMINISTIC"):
@@ -2153,37 +2165,24 @@ def test_module_buffer():
21532165
if torch.cuda.device_count():
21542166
module.cuda()
21552167
assert module.td.device.type == "cuda"
2168+
elif is_npu_available():
2169+
module = module.to("npu:0")
2170+
assert module.td.device.type == "npu"
21562171

21572172

21582173
@pytest.mark.parametrize(
21592174
"original_device",
21602175
[
21612176
None,
21622177
torch.device("cpu"),
2163-
(
2164-
torch.device("cuda:0")
2165-
if torch.cuda.is_available()
2166-
else (
2167-
torch.device("mps:0")
2168-
if torch.mps.is_available()
2169-
else torch.device("cpu")
2170-
)
2171-
),
2178+
get_device(),
21722179
],
21732180
)
21742181
@pytest.mark.parametrize(
21752182
"new_device",
21762183
[
21772184
torch.device("cpu"),
2178-
(
2179-
torch.device("cuda:0")
2180-
if torch.cuda.is_available()
2181-
else (
2182-
torch.device("mps:0")
2183-
if torch.mps.is_available()
2184-
else torch.device("cpu")
2185-
)
2186-
),
2185+
get_device(),
21872186
],
21882187
)
21892188
@pytest.mark.parametrize("tc", [True, False], ids=["tc", "td"])

test/test_tensorclass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pytest
2727
import tensordict.utils
2828
import torch
29+
from _utils_internal import is_npu_available
2930

3031
from tensordict import (
3132
assert_allclose_td,
@@ -45,6 +46,7 @@
4546
from tensordict._td import lazy_stack
4647
from tensordict.base import _GENERIC_NESTED_ERR
4748
from tensordict.tensorclass import from_dataclass
49+
4850
from torch import Tensor
4951

5052
_has_streaming = importlib.util.find_spec("streaming", None) is not None
@@ -2566,6 +2568,8 @@ def test_to(self):
25662568
td = self.get_nested()
25672569
if torch.cuda.is_available():
25682570
device = torch.device("cuda:0")
2571+
elif is_npu_available():
2572+
device = torch.device("npu:0")
25692573
else:
25702574
device = torch.device("cpu:1")
25712575
td_device = td.to(device)

test/test_tensordict.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2957,13 +2957,9 @@ def test_split_with_invalid_arguments(self):
29572957
td.split(1, 2)
29582958
with pytest.raises(IndexError, match="Incompatible dim"):
29592959
td.split(1, -3)
2960-
with pytest.raises(
2961-
RuntimeError, match="split_size must be a positive integer, but got 0."
2962-
):
2960+
with pytest.raises(ValueError, match="split_size must be positive, got 0."):
29632961
td.split(0, -1)
2964-
with pytest.raises(
2965-
RuntimeError, match="split_size must be a positive integer, but got -1."
2966-
):
2962+
with pytest.raises(ValueError, match="split_size must be positive, got -1."):
29672963
td.split(-1, -1)
29682964

29692965
def test_split_with_negative_dim(self):

0 commit comments

Comments (0)