|
| 1 | +.. currentmodule:: tensordict |
| 2 | + |
| 3 | +TypedTensorDict |
| 4 | +=============== |
| 5 | + |
| 6 | +:class:`~tensordict.TypedTensorDict` is a :class:`~tensordict.TensorDict` subclass |
| 7 | +with typed field declarations. It brings ``TypedDict``-style class definitions to |
| 8 | +``TensorDict``: you declare fields as class annotations and get typed construction, |
| 9 | +typed attribute access, inheritance, ``NotRequired`` fields, and ``**state`` spreading |
| 10 | +-- all while keeping every ``TensorDict`` operation available. |
| 11 | + |
| 12 | +.. code-block:: python |
| 13 | +
|
| 14 | + >>> import torch |
| 15 | + >>> from tensordict import TypedTensorDict |
| 16 | + >>> from torch import Tensor |
| 17 | + >>> |
| 18 | + >>> class PredictorState(TypedTensorDict): |
| 19 | + ... eta: Tensor |
| 20 | + ... X: Tensor |
| 21 | + ... beta: Tensor |
| 22 | + >>> |
| 23 | + >>> state = PredictorState( |
| 24 | + ... eta=torch.randn(5, 3), |
| 25 | + ... X=torch.randn(5, 4), |
| 26 | + ... beta=torch.randn(5, 1), |
| 27 | + ... batch_size=[5], |
| 28 | + ... ) |
| 29 | + >>> state.eta.shape |
| 30 | + torch.Size([5, 3]) |
| 31 | + >>> state["X"].shape |
| 32 | + torch.Size([5, 4]) |
| 33 | +
|
| 34 | +Why TypedTensorDict? |
| 35 | +-------------------- |
| 36 | + |
| 37 | +Typed pipelines often build up state one step at a time: |
| 38 | + |
| 39 | +.. code-block:: python |
| 40 | +
|
| 41 | + class PredictorState(TypedTensorDict): |
| 42 | + eta: Tensor |
| 43 | + X: Tensor |
| 44 | + beta: Tensor |
| 45 | +
|
| 46 | + class ObservedState(PredictorState): |
| 47 | + y: Tensor |
| 48 | + mu: Tensor |
| 49 | +
|
| 50 | + def gaussian(state: PredictorState, std: float) -> ObservedState: |
| 51 | + eta = state.eta |
| 52 | + y = eta + torch.randn_like(eta) * std |
| 53 | + return ObservedState(**state, y=y, mu=eta, batch_size=state.batch_size) |
| 54 | +
|
| 55 | +Each stage inherits the previous one's fields and adds new ones. The |
| 56 | +``**state`` spreading pattern lets transition functions stay short regardless |
| 57 | +of how many fields the state has. And because ``TypedTensorDict`` **is** a |
| 58 | +``TensorDict``, every operation -- ``.to(device)``, ``.clone()``, slicing, |
| 59 | +``torch.stack``, ``memmap`` -- works at every stage. |
| 60 | + |
| 61 | +TypedTensorDict vs TensorClass |
| 62 | +------------------------------ |
| 63 | + |
| 64 | +Both ``TypedTensorDict`` and ``TensorClass`` provide typed tensor containers. |
| 65 | +They share the same class-option syntax (``["shadow"]``, ``["frozen"]``, etc.) |
| 66 | +and both use ``@dataclass_transform()`` for IDE support. The key difference is |
| 67 | +in the underlying model: |
| 68 | + |
| 69 | +.. list-table:: |
| 70 | + :header-rows: 1 |
| 71 | + :widths: 35 30 30 |
| 72 | + |
| 73 | + * - Feature |
| 74 | + - ``TypedTensorDict`` |
| 75 | + - ``TensorClass`` |
| 76 | + * - Inherits from |
| 77 | + - ``TensorDict`` directly |
| 78 | + - ``TensorCollection`` (wraps a ``TensorDict`` internally) |
| 79 | + * - Inheritance |
| 80 | + - Standard Python (``class Child(Parent): ...``) |
| 81 | + - Supported via metaclass |
| 82 | + * - ``**state`` spreading |
| 83 | + - Works natively (``MutableMapping``) |
| 84 | + - Requires manual field-by-field repacking |
| 85 | + * - ``state["key"]`` |
| 86 | + - Works natively (``TensorDict.__getitem__``) |
| 87 | + - Raises ``ValueError`` -- use ``state.key`` or ``state.get("key")`` |
| 88 | + * - ``NotRequired`` fields |
| 89 | + - Supported |
| 90 | + - Not supported |
| 91 | + * - Non-tensor fields |
| 92 | + - Not supported (tensor-only) |
| 93 | + - Supported (strings, ints, arbitrary objects) |
| 94 | + * - Custom methods |
| 95 | + - Supported (regular class methods) |
| 96 | + - Supported (regular class methods) |
| 97 | + * - ``@tensorclass`` decorator |
| 98 | + - Not needed (uses metaclass via inheritance) |
| 99 | + - Required (or ``class Foo(TensorClass): ...``) |
| 100 | + |
| 101 | +**When to use which:** |
| 102 | + |
| 103 | +- Use ``TypedTensorDict`` when you have a typed pipeline with progressive state |
| 104 | + accumulation, need ``**state`` spreading, or want direct ``TensorDict`` |
| 105 | + interop without a wrapper layer. |
| 106 | + |
| 107 | +- Use ``TensorClass`` when you need non-tensor fields (strings, metadata), |
| 108 | + custom ``__init__`` logic, or your codebase already uses ``@tensorclass`` |
| 109 | + extensively. |
| 110 | + |
| 111 | +Inheritance and field accumulation |
| 112 | +---------------------------------- |
| 113 | + |
| 114 | +Fields accumulate through the MRO. Each subclass adds its own fields while |
| 115 | +inheriting all parent fields: |
| 116 | + |
| 117 | +.. code-block:: python |
| 118 | +
|
| 119 | + >>> from typing import NotRequired |
| 120 | + >>> |
| 121 | + >>> class PredictorState(TypedTensorDict): |
| 122 | + ... eta: Tensor |
| 123 | + ... X: Tensor |
| 124 | + ... beta: Tensor |
| 125 | + >>> |
| 126 | + >>> class ObservedState(PredictorState): |
| 127 | + ... y: Tensor |
| 128 | + ... mu: Tensor |
| 129 | + ... noise: NotRequired[Tensor] |
| 130 | + >>> |
| 131 | + >>> class SurvivalState(ObservedState): |
| 132 | + ... event_time: Tensor |
| 133 | + ... indicator: Tensor |
| 134 | + ... observed_time: Tensor |
| 135 | +
|
| 136 | + >>> ObservedState.__required_keys__ |
| 137 | + frozenset({'eta', 'X', 'beta', 'y', 'mu'}) |
| 138 | + >>> ObservedState.__optional_keys__ |
| 139 | + frozenset({'noise'}) |
| 140 | +
|
| 141 | +Inheritance works as standard Python: ``isinstance(obs, PredictorState)`` |
| 142 | +returns ``True`` for an ``ObservedState`` instance, and a function typed as |
| 143 | +``f(state: PredictorState)`` accepts any subclass. |
| 144 | + |
| 145 | +NotRequired fields |
| 146 | +------------------ |
| 147 | + |
| 148 | +Mark fields as optional with :data:`~typing.NotRequired`: |
| 149 | + |
| 150 | +.. code-block:: python |
| 151 | +
|
| 152 | + >>> from typing import NotRequired |
| 153 | + >>> |
| 154 | + >>> class ObservedState(PredictorState): |
| 155 | + ... y: Tensor |
| 156 | + ... mu: Tensor |
| 157 | + ... noise: NotRequired[Tensor] |
| 158 | +
|
| 159 | + >>> obs = ObservedState( |
| 160 | + ... eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1), |
| 161 | + ... y=torch.randn(5, 3), mu=torch.randn(5, 3), |
| 162 | + ... batch_size=[5], |
| 163 | + ... ) |
| 164 | + >>> "noise" in obs |
| 165 | + False |
| 166 | +
|
| 167 | +If a ``NotRequired`` field is not provided, it is simply absent from the |
| 168 | +underlying ``TensorDict``. Accessing it via attribute raises |
| 169 | +``AttributeError``. |
| 170 | + |
| 171 | +Spreading (``**state``) |
| 172 | +----------------------- |
| 173 | + |
| 174 | +Because ``TypedTensorDict`` is a ``MutableMapping``, the ``**`` operator |
| 175 | +unpacks it into keyword arguments. This makes state transitions concise: |
| 176 | + |
| 177 | +.. code-block:: python |
| 178 | +
|
| 179 | + >>> state = PredictorState( |
| 180 | + ... eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1), |
| 181 | + ... batch_size=[5], |
| 182 | + ... ) |
| 183 | + >>> obs = ObservedState( |
| 184 | + ... **state, |
| 185 | + ... y=torch.randn(5, 3), |
| 186 | + ... mu=torch.randn(5, 3), |
| 187 | + ... batch_size=state.batch_size, |
| 188 | + ... ) |
| 189 | + >>> set(obs.keys()) == {"eta", "X", "beta", "y", "mu"} |
| 190 | + True |
| 191 | +
|
| 192 | +Adding a new field to a pipeline stage is one line in the class definition -- |
| 193 | +no transition function needs updating. |
| 194 | + |
| 195 | +Class options |
| 196 | +------------- |
| 197 | + |
| 198 | +``TypedTensorDict`` supports the same bracket-syntax options as ``TensorClass``: |
| 199 | + |
| 200 | +.. code-block:: python |
| 201 | +
|
| 202 | + class MyModel(TypedTensorDict["shadow"]): |
| 203 | + data: Tensor # "data" shadows TensorDict.data -- allowed |
| 204 | +
|
| 205 | + class Immutable(TypedTensorDict["frozen"]): |
| 206 | + x: Tensor # locked after construction |
| 207 | +
|
| 208 | + class Combined(TypedTensorDict["shadow", "frozen"]): |
| 209 | + data: Tensor |
| 210 | +
|
| 211 | +- ``"shadow"`` -- Allow field names that clash with ``TensorDict`` attributes. |
| 212 | + Without this, conflicting names raise ``AttributeError`` at class definition |
| 213 | + time. |
| 214 | +- ``"frozen"`` -- Lock the ``TensorDict`` after construction (read-only). |
| 215 | +- ``"autocast"`` -- Automatically cast assigned values. |
| 216 | +- ``"nocast"`` -- Disable type casting on assignment. |
| 217 | +- ``"tensor_only"`` -- Restrict fields to tensor types only. |
| 218 | + |
| 219 | +Options propagate through inheritance: a subclass of a ``"frozen"`` class is |
| 220 | +also frozen. |
| 221 | + |
| 222 | +TensorDict operations |
| 223 | +--------------------- |
| 224 | + |
| 225 | +Every ``TensorDict`` operation works on ``TypedTensorDict`` instances: |
| 226 | + |
| 227 | +.. code-block:: python |
| 228 | +
|
| 229 | + >>> state = PredictorState( |
| 230 | + ... eta=torch.randn(5, 3), X=torch.randn(5, 4), beta=torch.randn(5, 1), |
| 231 | + ... batch_size=[5], |
| 232 | + ... ) |
| 233 | + >>> state.to("cpu").device |
| 234 | + device(type='cpu') |
| 235 | + >>> state.clone()["eta"].shape |
| 236 | + torch.Size([5, 3]) |
| 237 | + >>> state[0:3].batch_size |
| 238 | + torch.Size([3]) |
| 239 | + >>> torch.stack([state, state], dim=0).batch_size |
| 240 | + torch.Size([2, 5]) |
| 241 | +
|
| 242 | +This includes ``.memmap()``, ``.apply()``, ``torch.cat``, ``torch.stack``, |
| 243 | +``.unbind()``, ``.select()``, ``.exclude()``, ``.update()``, and all other |
| 244 | +``TensorDict`` methods. |
| 245 | + |
| 246 | +Type checking |
| 247 | +------------- |
| 248 | + |
| 249 | +``TypedTensorDict`` uses ``@dataclass_transform()`` (PEP 681) on its metaclass. |
| 250 | +This means type checkers (pyright, mypy) understand: |
| 251 | + |
| 252 | +- **Constructor signatures** -- missing or extra fields are flagged. |
| 253 | +- **Attribute access** -- ``state.eta`` is typed as ``Tensor``, and typos like |
| 254 | + ``state.etta`` produce errors. |
| 255 | +- **Inheritance** -- subclass fields include parent fields. |
| 256 | + |
| 257 | +String-key access (``state["eta"]``) works at runtime but does not get type |
| 258 | +narrowing without a dedicated type checker plugin. For typed access, prefer |
| 259 | +dot notation (``state.eta``). |
| 260 | + |
| 261 | +.. autosummary:: |
| 262 | + :toctree: generated/ |
| 263 | + :template: td_template.rst |
| 264 | + |
| 265 | + TypedTensorDict |
0 commit comments