Skip to content

Commit 58d7ed5

Browse files
committed
[Feature] Address PR review comments
Remove _SAFE_SHADOW_NAMES -- all field names that shadow TensorDict attributes now require explicit TypedTensorDict["shadow"]. Add Python version comments to try/except import fallbacks. Made-with: Cursor
1 parent fe63bfb commit 58d7ed5

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

tensordict/typedtensordict.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from tensordict.base import NO_DEFAULT
1515

1616
try:
17+
# Python 3.11+ (PEP 681)
1718
from typing import dataclass_transform
1819
except ImportError:
1920

@@ -25,13 +26,11 @@ def _identity(cls):
2526

2627

2728
try:
29+
# Python 3.11+ (PEP 655)
2830
from typing import NotRequired
2931
except ImportError:
3032
from typing_extensions import NotRequired # noqa: F401
3133

32-
# Fields on TensorDict that are safe to shadow without the "shadow" option.
33-
_SAFE_SHADOW_NAMES = frozenset({"_is_non_tensor", "data"})
34-
3534
# Annotation names that are class-level metadata, not user fields.
3635
_META_FIELDS = frozenset(
3736
{
@@ -262,7 +261,7 @@ def __new__(
262261
if not cls._shadow:
263262
td_dir = _get_td_dir()
264263
for attr in expected:
265-
if attr in td_dir and attr not in _SAFE_SHADOW_NAMES:
264+
if attr in td_dir:
266265
raise AttributeError(
267266
f"Field '{attr}' shadows a TensorDict attribute. "
268267
f"Use TypedTensorDict['shadow'] to allow this."
@@ -273,9 +272,9 @@ def __new__(
273272
cls.__optional_keys__ = optional
274273

275274
# Generate properties for fields that clash with TensorDict attributes
276-
# so they override the parent's version. When shadow=False, only
277-
# _SAFE_SHADOW_NAMES fields (like "data") get properties; the rest
278-
# were already rejected above.
275+
# so they override the parent's version. When shadow=False these
276+
# fields were already rejected above, so this only runs when
277+
# shadow=True.
279278
td_dir = _get_td_dir()
280279
for attr in expected:
281280
if attr in td_dir:

test/test_typedtensordict.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -520,16 +520,13 @@ class Empty(TypedTensorDict):
520520
assert e.batch_size == torch.Size([3])
521521
assert len(e.keys()) == 0
522522

523-
def test_safe_shadow_data_without_option(self):
524-
"""'data' is in _SAFE_SHADOW_NAMES so it works without shadow=True."""
525-
526-
class WithData(TypedTensorDict):
527-
data: Tensor
528-
x: Tensor
523+
def test_data_requires_shadow(self):
524+
"""'data' shadows a TensorDict attribute and requires shadow=True."""
525+
with pytest.raises(AttributeError, match="shadows a TensorDict attribute"):
529526

530-
wd = WithData(data=torch.randn(3, 2), x=torch.randn(3, 4), batch_size=[3])
531-
assert wd.data.shape == (3, 2)
532-
assert wd.x.shape == (3, 4)
527+
class WithData(TypedTensorDict):
528+
data: Tensor
529+
x: Tensor
533530

534531
def test_batch_iter(self):
535532
"""__iter__ on TypedTensorDict iterates over batch dimension."""

0 commit comments

Comments
 (0)