Commit 7fba625

[Feature] TypedTensorDict (#1657)
1 parent 8c84dcb commit 7fba625

9 files changed: +1341 -8 lines changed

docs/source/reference/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ API Reference
    td
    nn
    tc
+   ttd

docs/source/reference/ttd.rst

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
.. currentmodule:: tensordict

TypedTensorDict
===============

:class:`~tensordict.TypedTensorDict` is a :class:`~tensordict.TensorDict` subclass
with typed field declarations. It brings ``TypedDict``-style class definitions to
``TensorDict``: you declare fields as class annotations and get typed construction,
typed attribute access, inheritance, ``NotRequired`` fields, and ``**state`` spreading
-- all while keeping every ``TensorDict`` operation available.

.. code-block:: python

    >>> import torch
    >>> from tensordict import TypedTensorDict
    >>> from torch import Tensor
    >>>
    >>> class PredictorState(TypedTensorDict):
    ...     eta: Tensor
    ...     X: Tensor
    ...     beta: Tensor
    >>>
    >>> state = PredictorState(
    ...     eta=torch.randn(5, 3),
    ...     X=torch.randn(5, 4),
    ...     beta=torch.randn(5, 1),
    ...     batch_size=[5],
    ... )
    >>> state.eta.shape
    torch.Size([5, 3])
    >>> state["X"].shape
    torch.Size([5, 4])

Why TypedTensorDict?
--------------------

Typed pipelines often build up state one step at a time:

.. code-block:: python

    import torch
    from tensordict import TypedTensorDict
    from torch import Tensor

    class PredictorState(TypedTensorDict):
        eta: Tensor
        X: Tensor
        beta: Tensor

    class ObservedState(PredictorState):
        y: Tensor
        mu: Tensor

    def gaussian(state: PredictorState, std: float) -> ObservedState:
        eta = state.eta
        # Add Gaussian noise with standard deviation ``std`` to the predictor.
        y = eta + torch.randn_like(eta) * std
        return ObservedState(**state, y=y, mu=eta, batch_size=state.batch_size)

Each stage inherits the previous one's fields and adds new ones. The
``**state`` spreading pattern lets transition functions stay short regardless
of how many fields the state has. And because ``TypedTensorDict`` **is** a
``TensorDict``, every operation -- ``.to(device)``, ``.clone()``, slicing,
``torch.stack``, ``memmap`` -- works at every stage.

TypedTensorDict vs TensorClass
------------------------------

Both ``TypedTensorDict`` and ``TensorClass`` provide typed tensor containers.
They share the same class-option syntax (``["shadow"]``, ``["frozen"]``, etc.)
and both use ``@dataclass_transform()`` for IDE support. The key difference is
in the underlying model:

.. list-table::
   :header-rows: 1
   :widths: 35 30 30

   * - Feature
     - ``TypedTensorDict``
     - ``TensorClass``
   * - Inherits from
     - ``TensorDict`` directly
     - ``TensorCollection`` (wraps a ``TensorDict`` internally)
   * - Inheritance
     - Standard Python (``class Child(Parent): ...``)
     - Supported via metaclass
   * - ``**state`` spreading
     - Works natively (``MutableMapping``)
     - Requires manual field-by-field repacking
   * - ``state["key"]``
     - Works natively (``TensorDict.__getitem__``)
     - Raises ``ValueError`` -- use ``state.key`` or ``state.get("key")``
   * - ``NotRequired`` fields
     - Supported
     - Not supported
   * - Non-tensor fields
     - Not supported (tensor-only)
     - Supported (strings, ints, arbitrary objects)
   * - Custom methods
     - Supported (regular class methods)
     - Supported (regular class methods)
   * - ``@tensorclass`` decorator
     - Not needed (uses metaclass via inheritance)
     - Required (or ``class Foo(TensorClass): ...``)

**When to use which:**

- Use ``TypedTensorDict`` when you have a typed pipeline with progressive state
  accumulation, need ``**state`` spreading, or want direct ``TensorDict``
  interop without a wrapper layer.

- Use ``TensorClass`` when you need non-tensor fields (strings, metadata),
  custom ``__init__`` logic, or your codebase already uses ``@tensorclass``
  extensively.

Inheritance and field accumulation
----------------------------------

Fields accumulate through the MRO. Each subclass adds its own fields while
inheriting all parent fields:

.. code-block:: python

    >>> from typing import NotRequired
    >>>
    >>> class PredictorState(TypedTensorDict):
    ...     eta: Tensor
    ...     X: Tensor
    ...     beta: Tensor
    >>>
    >>> class ObservedState(PredictorState):
    ...     y: Tensor
    ...     mu: Tensor
    ...     noise: NotRequired[Tensor]
    >>>
    >>> class SurvivalState(ObservedState):
    ...     event_time: Tensor
    ...     indicator: Tensor
    ...     observed_time: Tensor

    >>> ObservedState.__required_keys__
    frozenset({'eta', 'X', 'beta', 'y', 'mu'})
    >>> ObservedState.__optional_keys__
    frozenset({'noise'})

Inheritance follows standard Python semantics: ``isinstance(obs, PredictorState)``
returns ``True`` for an ``ObservedState`` instance, and a function typed as
``f(state: PredictorState)`` accepts any subclass.

NotRequired fields
------------------

Mark fields as optional with :data:`~typing.NotRequired`:

.. code-block:: python

    >>> from typing import NotRequired
    >>>
    >>> class ObservedState(PredictorState):
    ...     y: Tensor
    ...     mu: Tensor
    ...     noise: NotRequired[Tensor]

    >>> obs = ObservedState(
    ...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
    ...     y=torch.randn(5, 3), mu=torch.randn(5, 3),
    ...     batch_size=[5],
    ... )
    >>> "noise" in obs
    False

If a ``NotRequired`` field is not provided, it is simply absent from the
underlying ``TensorDict``. Accessing it via attribute raises
``AttributeError``.

Spreading (``**state``)
-----------------------

Because ``TypedTensorDict`` is a ``MutableMapping``, the ``**`` operator
unpacks it into keyword arguments. This makes state transitions concise:

.. code-block:: python

    >>> state = PredictorState(
    ...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
    ...     batch_size=[5],
    ... )
    >>> obs = ObservedState(
    ...     **state,
    ...     y=torch.randn(5, 3),
    ...     mu=torch.randn(5, 3),
    ...     batch_size=state.batch_size,
    ... )
    >>> set(obs.keys()) == {"eta", "X", "beta", "y", "mu"}
    True

Adding a new field to a pipeline stage is one line in the class definition --
no transition function needs updating.
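
To make the mapping protocol behind ``**`` concrete, here is a stdlib-only sketch. The ``State`` class is a hypothetical dict-backed stand-in, not ``TypedTensorDict`` itself; it only shows that any ``MutableMapping`` can be spread into keyword arguments, and that Python rejects a key supplied both through the spread and explicitly.

```python
from collections.abc import MutableMapping


class State(MutableMapping):
    """Hypothetical dict-backed container, used only to illustrate ``**``."""

    def __init__(self, **fields):
        self._fields = dict(fields)

    def __getitem__(self, key):
        return self._fields[key]

    def __setitem__(self, key, value):
        self._fields[key] = value

    def __delitem__(self, key):
        del self._fields[key]

    def __iter__(self):
        return iter(self._fields)

    def __len__(self):
        return len(self._fields)


state = State(eta=1.0, X=2.0, beta=3.0)

# ``**state`` reads the mapping's keys and values to build keyword arguments.
obs = State(**state, y=4.0, mu=5.0)
keys = sorted(obs)

# A key passed both via ``**state`` and explicitly raises TypeError at the call.
try:
    State(**state, eta=0.0)
    duplicate_rejected = False
except TypeError:
    duplicate_rejected = True
```

The duplicate-key behavior is a property of Python call syntax, so a transition function that wants to override an existing field has to remove or replace it before spreading.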

Class options
-------------

``TypedTensorDict`` supports the same bracket-syntax options as ``TensorClass``:

.. code-block:: python

    class MyModel(TypedTensorDict["shadow"]):
        data: Tensor  # "data" shadows TensorDict.data -- allowed

    class Immutable(TypedTensorDict["frozen"]):
        x: Tensor  # locked after construction

    class Combined(TypedTensorDict["shadow", "frozen"]):
        data: Tensor

- ``"shadow"`` -- Allow field names that clash with ``TensorDict`` attributes.
  Without this, conflicting names raise ``AttributeError`` at class definition
  time.
- ``"frozen"`` -- Lock the ``TensorDict`` after construction (read-only).
- ``"autocast"`` -- Automatically cast assigned values.
- ``"nocast"`` -- Disable type casting on assignment.
- ``"tensor_only"`` -- Restrict fields to tensor types only.

Options propagate through inheritance: a subclass of a ``"frozen"`` class is
also frozen.
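
One way bracket-syntax options with inheritance could be built in plain Python is ``__class_getitem__``. The sketch below is an assumption for illustration only and is not taken from the ``tensordict`` source: ``Configurable`` and its ``_options`` attribute are hypothetical names, and the options are merely recorded, not enforced.

```python
class Configurable:
    """Hypothetical base class; not tensordict's implementation."""

    _options: frozenset = frozenset()

    def __class_getitem__(cls, item):
        # ``Base["a", "b"]`` passes a tuple; ``Base["a"]`` passes a single string.
        names = item if isinstance(item, tuple) else (item,)
        merged = cls._options | frozenset(names)
        # Return an intermediate subclass that carries the merged options.
        return type(cls.__name__, (cls,), {"_options": merged})


class Frozen(Configurable["frozen"]):
    pass


class Child(Frozen):  # no brackets: options are inherited as a class attribute
    pass


class Combined(Configurable["shadow", "frozen"]):
    pass
```

Because the options live on a class attribute, ``Child`` sees ``{"frozen"}`` through ordinary attribute lookup, which is one simple way to get the propagation behavior described above.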

TensorDict operations
---------------------

Every ``TensorDict`` operation works on ``TypedTensorDict`` instances:

.. code-block:: python

    >>> state = PredictorState(
    ...     eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1),
    ...     batch_size=[5],
    ... )
    >>> state.to("cpu").device
    device(type='cpu')
    >>> state.clone()["eta"].shape
    torch.Size([5, 3])
    >>> state[0:3].batch_size
    torch.Size([3])
    >>> torch.stack([state, state], dim=0).batch_size
    torch.Size([2, 5])

This includes ``.memmap()``, ``.apply()``, ``torch.cat``, ``torch.stack``,
``.unbind()``, ``.select()``, ``.exclude()``, ``.update()``, and all other
``TensorDict`` methods.

Type checking
-------------

``TypedTensorDict`` uses ``@dataclass_transform()`` (PEP 681) on its metaclass.
This means type checkers (pyright, mypy) understand:

- **Constructor signatures** -- missing or extra fields are flagged.
- **Attribute access** -- ``state.eta`` is typed as ``Tensor``, and typos like
  ``state.etta`` produce errors.
- **Inheritance** -- subclass fields include parent fields.

String-key access (``state["eta"]``) works at runtime but does not get type
narrowing without a dedicated type checker plugin. For typed access, prefer
dot notation (``state.eta``).

.. autosummary::
    :toctree: generated/
    :template: td_template.rst

    TypedTensorDict

tensordict/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -56,6 +56,7 @@
     tensorclass,
     TensorClass,
 )
+from tensordict.typedtensordict import TypedTensorDict
 from tensordict.utils import (
     assert_allclose_td,
     assert_close,
@@ -161,7 +162,6 @@
     "set_lazy_legacy",
     "list_to_stack",
     "set_list_to_stack",
-    "set_printoptions",
     "get_printoptions",
     # TensorClass components
     "tensorclass",

tensordict/__init__.pyi

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@ from tensordict.tensorclass import (
     TensorClass,
     tensorclass,
 )
+from tensordict.typedtensordict import TypedTensorDict
 from tensordict.utils import (
     assert_allclose_td,
     assert_close,
@@ -81,6 +82,7 @@ __all__ = [
     "LazyStackedTensorDict",
     "UnbatchedTensor",
     "TensorClass",
+    "TypedTensorDict",
     "MemoryMappedTensor",
     "PersistentTensorDict",
     "NestedKey",
