Commit 173c2bd
[Refactor] TypedTensorDict redesign (#1662)
1 parent d8f32e4

9 files changed, +1324 -137 lines

docs/source/compatibility.rst
Lines changed: 159 additions & 60 deletions
@@ -15,7 +15,7 @@ Architecture overview
     TensorCollection
     ├── TensorDictBase
     │   ├── TensorDict (in-memory)
-    │   └── TypedTensorDict (typed fields, IS-A TensorDict)
+    │   ├── TypedTensorDict (typed fields, wraps any TensorDictBase)
     │   ├── PersistentTensorDict (HDF5-backed)
     │   ├── TensorDictStore (Redis / Dragonfly / KeyDB)
     │   └── LazyStackedTensorDict (lazy stack of heterogeneous TDs)
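The IS-A relationships in the tree above can be made concrete with a toy set of empty Python classes. These are illustrative stand-ins only, not the real implementations:

```python
# Toy mirror of the class tree; bodies elided, names follow the diagram.
class TensorCollection: ...
class TensorDictBase(TensorCollection): ...
class TensorDict(TensorDictBase): ...             # in-memory
class TypedTensorDict(TensorDictBase): ...        # typed fields, wraps any TensorDictBase
class PersistentTensorDict(TensorDictBase): ...   # HDF5-backed
class TensorDictStore(TensorDictBase): ...        # Redis / Dragonfly / KeyDB
class LazyStackedTensorDict(TensorDictBase): ...  # lazy stack of heterogeneous TDs

# Every concrete container is a TensorDictBase, hence a TensorCollection.
assert issubclass(TypedTensorDict, TensorDictBase)
assert issubclass(TensorDictStore, TensorCollection)
```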
@@ -26,8 +26,11 @@ Two patterns exist for adding typed field declarations:
 
 - **TensorClass** wraps any ``TensorDictBase`` via ``from_tensordict(td)``.
   It delegates all storage to the wrapped object.
-- **TypedTensorDict** *is* a ``TensorDict``. It stores data in-memory and
-  interoperates with other backends through conversion or stacking.
+- **TypedTensorDict** wraps any ``TensorDictBase`` via ``from_tensordict(td)``,
+  similar to ``TensorClass``. Direct construction creates a ``TensorDict``
+  internally. Unlike ``TensorClass``, it inherits from ``TensorDictBase``
+  directly, supports ``**state`` spreading natively, and uses standard
+  Python inheritance for schema composition.
 
 TensorClass + backends
 ----------------------
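The practical difference between the two patterns can be sketched with plain dicts standing in for tensor storage. All class names here are invented for illustration; this is not the library's code:

```python
from collections.abc import MutableMapping

class WrapperClass:
    """HAS-A sketch: wraps a backing mapping, TensorClass-style."""
    def __init__(self, backend):
        self._backend = backend            # delegate all storage

    def __getattr__(self, name):
        return self._backend[name]         # attribute reads hit the backend

class TypedDict(MutableMapping):
    """IS-A sketch: is itself a mapping, TypedTensorDict-style."""
    def __init__(self, **fields):
        self._data = dict(fields)
    def __getitem__(self, k): return self._data[k]
    def __setitem__(self, k, v): self._data[k] = v
    def __delitem__(self, k): del self._data[k]
    def __iter__(self): return iter(self._data)
    def __len__(self): return len(self._data)

def f(x=0, y=0):
    return x + y

state = TypedDict(x=1, y=2)
print(f(**state))    # ** spreading works natively on a MutableMapping -> 3

wrapped = WrapperClass({"x": 1, "y": 2})
print(wrapped.x)     # attribute access delegates to the backend -> 1
# f(**wrapped) would raise TypeError: the plain wrapper is not a mapping,
# so it would need field-by-field repacking first.
```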
@@ -205,109 +208,162 @@ enforce schemas, but they compose without conflict:
 TypedTensorDict + backends
 --------------------------
 
-``TypedTensorDict`` is a ``TensorDict`` subclass. It stores data in-memory
-but interoperates with other backends through conversion or stacking.
+``TypedTensorDict.from_tensordict(td)`` accepts any ``TensorDictBase`` subclass,
+just like ``TensorClass``. The backend is stored live (no copy) -- mutations
+through the ``TypedTensorDict`` go directly to the underlying backend.
+
+.. code-block:: python
+
+    from tensordict import TypedTensorDict
+    from torch import Tensor
+
+    class State(TypedTensorDict):
+        x: Tensor
+        y: Tensor
+
+    state = State.from_tensordict(some_backend)
 
 .. list-table::
    :header-rows: 1
-   :widths: 30 12 12 12 12 12
+   :widths: 22 10 10 10 10 10 10 10 10
 
-   * - Pattern
+   * - Backend
     - Build
     - Read
     - Write
     - Index
+    - Clone
     - Stack
-   * - Direct construction
-    - yes
+    - Iter
+    - Update
+   * - ``TensorDict``
     - yes
     - yes
     - yes
     - yes
-   * - From H5 (materialise then construct)
     - yes
     - yes
     - yes
     - yes
+   * - ``PersistentTensorDict`` (H5)
     - yes
-   * - From Redis (materialise then construct)
     - yes
     - yes
     - yes
     - yes
     - yes
-   * - From lazy stack (materialise then construct)
     - yes
     - yes
+   * - ``TensorDictStore`` (Redis)
     - yes
     - yes
     - yes
-   * - ``torch.stack`` (dense)
     - yes
     - yes
     - yes
     - yes
-    - --
-   * - ``LazyStackedTensorDict`` of TTDs
     - yes
+   * - ``LazyStackedTensorDict``
     - yes
     - yes
     - yes
-    - --
-   * - ``memmap_()``
     - yes
     - yes
-    - set\_()
     - yes
     - yes
-   * - To H5 (``PersistentTensorDict.from_dict``)
     - yes
+   * - ``TensorDict`` (memmap)
     - yes
-    - H5 rules
     - yes
-    - --
-   * - To Redis (``TensorDictStore.from_tensordict``)
+    - set\_()
     - yes
     - yes
     - yes
     - yes
-    - --
+    - update\_()
 
-Constructing TypedTensorDict from other backends
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. note::
 
-Since ``TypedTensorDict`` is an in-memory ``TensorDict``, loading data from a
-remote or persistent backend requires materialising the data first:
+   Memory-mapped TensorDicts are locked after ``memmap_()``. Use
+   ``set_()`` and ``update_()`` for in-place writes instead of attribute
+   assignment or ``update()``.
+
+Building a TypedTensorDict on each backend
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**In-memory TensorDict** -- the default (direct construction creates one
+internally):
 
 .. code-block:: python
 
     >>> import torch
-    >>> from tensordict import TypedTensorDict
+    >>> from tensordict import TensorDict, TypedTensorDict
     >>> from torch import Tensor
    >>>
     >>> class State(TypedTensorDict):
     ...     x: Tensor
     ...     y: Tensor
+    >>>
+    >>> state = State(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
+    >>> state.x.shape
+    torch.Size([4, 3])
 
-**From HDF5**:
+**Wrapping an existing TensorDict** via ``from_tensordict`` (zero-copy):
+
+.. code-block:: python
+
+    >>> td = TensorDict(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
+    >>> state = State.from_tensordict(td)
+    >>> state.x.shape  # reads from td
+    torch.Size([4, 3])
+    >>> state.x = torch.ones(4, 3)  # writes to td
+    >>> (td["x"] == 1).all()
+    tensor(True)
+
+**HDF5 (PersistentTensorDict)**:
 
 .. code-block:: python
 
     >>> from tensordict import PersistentTensorDict
     >>>
     >>> h5 = PersistentTensorDict.from_h5("data.h5")
-    >>> local = h5.to_tensordict()
-    >>> state = State(x=local["x"], y=local["y"], batch_size=local.batch_size)
+    >>> state = State.from_tensordict(h5)
+    >>> state.x.shape  # reads from HDF5
+    torch.Size([4, 3])
 
-**From a lazy stack**:
+**Redis (TensorDictStore)**:
+
+.. code-block:: python
+
+    >>> from tensordict.store import TensorDictStore
+    >>>
+    >>> store = TensorDictStore.from_tensordict(td, host="localhost")
+    >>> state = State.from_tensordict(store)
+    >>> state.x.shape  # fetched from Redis
+    torch.Size([4, 3])
+
+**Lazy stack**:
 
 .. code-block:: python
 
     >>> from tensordict import lazy_stack
     >>>
-    >>> ls = lazy_stack([td1, td2], dim=0)
-    >>> local = ls.to_tensordict()
-    >>> state = State(x=local["x"], y=local["y"], batch_size=local.batch_size)
+    >>> tds = [TensorDict(x=torch.randn(3), y=torch.randn(5)) for _ in range(4)]
+    >>> ls = lazy_stack(tds, dim=0)
+    >>> state = State.from_tensordict(ls)
+    >>> state[0].x.shape
+    torch.Size([3])
+
+**Memory-mapped TensorDict**:
+
+.. code-block:: python
+
+    >>> td_mmap = td.memmap_("/tmp/my_memmap")
+    >>> state = State.from_tensordict(td_mmap)
+    >>> state.x.shape
+    torch.Size([4, 3])
+    >>> # memmap TDs are locked -- use in-place operations:
+    >>> state.set_("x", torch.ones(4, 3))
 
 Stacking TypedTensorDicts
 ^^^^^^^^^^^^^^^^^^^^^^^^^
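The "backend is stored live (no copy)" behaviour this hunk documents can be sketched with a minimal stand-in, where a plain dict plays the backend. The class and method names here are illustrative only:

```python
class TypedView:
    """Toy typed view over a backing mapping: reads and writes pass
    straight through to the wrapped object, so no data is copied."""

    _fields = ("x", "y")

    @classmethod
    def from_backend(cls, backend):
        # Bypass __setattr__ so _backend itself is not forwarded.
        view = cls.__new__(cls)
        object.__setattr__(view, "_backend", backend)
        return view

    def __getattr__(self, name):
        if name in self._fields:
            return self._backend[name]     # read hits the backend
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in self._fields:
            self._backend[name] = value    # write hits the backend
        else:
            object.__setattr__(self, name, value)

backend = {"x": [1, 2], "y": [3]}
state = TypedView.from_backend(backend)
state.x = [9, 9]        # write through the view...
print(backend["x"])     # ...is visible in the wrapped backend: [9, 9]
```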
@@ -334,30 +390,69 @@ Lazy stacking also works. Indexing a ``LazyStackedTensorDict`` of
     >>> isinstance(ls[0], State)
     True
 
-Saving TypedTensorDict to persistent backends
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _compat-redis-prealloc:
+
+Pre-allocating on Redis and filling iteratively
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Since ``TypedTensorDict`` is a ``TensorDict``, it can be saved to HDF5, Redis,
-or memory-mapped storage directly:
+A common pattern for shared replay buffers or distributed data stores is to
+pre-allocate storage on a remote server (Redis / Dragonfly / KeyDB) and fill
+it one sample at a time, without ever loading the full dataset into RAM.
+
+``TensorDictStore.from_schema`` creates keys with known shapes and dtypes
+directly on the server using ``SETRANGE`` (zero-filled by the server; no
+tensor data passes through Python):
 
 .. code-block:: python
 
-    >>> # To HDF5
-    >>> from tensordict import PersistentTensorDict
-    >>> h5 = PersistentTensorDict.from_dict(state, filename="state.h5")
+    >>> import torch
+    >>> from tensordict import TensorDict, TypedTensorDict
+    >>> from tensordict.store import TensorDictStore
+    >>> from torch import Tensor
     >>>
-    >>> # To memmap
-    >>> state.memmap_("/tmp/state_mmap")
+    >>> class Replay(TypedTensorDict):
+    ...     obs: Tensor
+    ...     action: Tensor
+    ...     reward: Tensor
     >>>
-    >>> # To Redis
-    >>> from tensordict.store import TensorDictStore
-    >>> store = TensorDictStore.from_tensordict(state, host="localhost")
+    >>> # Pre-allocate 100k entries directly on Redis -- no RAM used
+    >>> store = TensorDictStore.from_schema(
+    ...     {"obs": ([84, 84, 3], torch.uint8),
+    ...      "action": ([4], torch.float32),
+    ...      "reward": ([], torch.float32)},
+    ...     batch_size=[100_000],
+    ...     host="redis-node",
+    ... )
+    >>>
+    >>> # Wrap with typed access
+    >>> replay = Replay.from_tensordict(store)
+    >>>
+    >>> # Fill iteratively -- each write goes directly to Redis
+    >>> for i, sample in enumerate(data_stream):
+    ...     replay[i] = Replay(
+    ...         obs=sample.obs, action=sample.action, reward=sample.reward,
+    ...         batch_size=[],
+    ...     )
+
+If the store is initially empty (no keys registered yet), use ``check=False``
+to skip the key-presence validation and fill keys on the fly:
+
+.. code-block:: python
+
+    >>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
+    >>> replay = Replay.from_tensordict(store, check=False)
+    >>>
+    >>> # First indexed write auto-creates each key via SETRANGE
+    >>> replay[0] = Replay(obs=obs_0, action=act_0, reward=r_0, batch_size=[])
+    >>> # Subsequent writes fill in the pre-allocated storage
+    >>> replay[1] = Replay(obs=obs_1, action=act_1, reward=r_1, batch_size=[])
 
 
 TensorClass vs TypedTensorDict
 ------------------------------
 
-Both enforce typed schemas but differ architecturally:
+Both enforce typed schemas and can wrap any ``TensorDictBase`` backend, but
+they differ architecturally:
 
 .. list-table::
    :header-rows: 1
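The pre-allocate-then-fill pattern from the Redis section above can be sketched with a toy in-Python "server", where each key is a zero-filled bytearray and indexed writes splice bytes in at an offset, loosely mimicking Redis ``SETRANGE``. All names in this sketch are illustrative, not the library's API:

```python
import struct

class ToyStore:
    """Toy key/value server: pre-allocates zero-filled buffers and
    fills them one sample at a time, SETRANGE-style."""

    def __init__(self):
        self._kv = {}

    def from_schema(self, schema, batch):
        # schema maps key -> bytes per sample; zero-fill everything up front,
        # like SETRANGE padding a key with \x00 on the server.
        for key, width in schema.items():
            self._kv[key] = bytearray(batch * width)

    def setrange(self, key, offset, payload):
        self._kv[key][offset:offset + len(payload)] = payload

    def getrange(self, key, offset, length):
        return bytes(self._kv[key][offset:offset + length])

store = ToyStore()
store.from_schema({"reward": 4}, batch=1000)   # 1000 float32 slots, all zero

# Fill iteratively: sample i lands at byte offset i * 4, no full-buffer copy.
for i, r in enumerate([0.5, 1.5, 2.5]):
    store.setrange("reward", i * 4, struct.pack("<f", r))

print(struct.unpack("<f", store.getrange("reward", 4, 4))[0])  # slot 1 -> 1.5
```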
@@ -366,12 +461,12 @@ Both enforce typed schemas but differ architecturally:
    * - Aspect
     - ``TensorClass``
     - ``TypedTensorDict``
-   * - Relationship to ``TensorDict``
-    - Wraps a ``TensorDictBase`` (HAS-A)
-    - Is a ``TensorDict`` (IS-A)
+   * - Relationship to ``TensorDictBase``
+    - Wraps a ``TensorDictBase`` (HAS-A via ``TensorCollection``)
+    - Is a ``TensorDictBase`` (IS-A, delegates to ``_source``)
    * - Can wrap non-TensorDict backends
     - Yes (H5, Redis, lazy stack, etc.)
-    - No (in-memory only; convert first)
+    - Yes (H5, Redis, lazy stack, etc.)
    * - ``**state`` spreading
     - Field-by-field repacking
     - Natively (``MutableMapping``)
@@ -380,15 +475,19 @@ Both enforce typed schemas but differ architecturally:
     - Not supported (tensor-only)
    * - Backend stays live
     - Yes (writes go to original backend)
-    - No (data is in-memory after construction)
+    - Yes (writes go to original backend)
+   * - Python inheritance
+    - Not supported
+    - Supported (standard class hierarchy)
    * - Composable with each other
     - Yes (``TC.from_tensordict(ttd)`` works)
-    - N/A
+    - Yes (``TTD.from_tensordict(tc._tensordict)`` works)
 
-When a ``TensorClass`` wraps a persistent backend (H5, Redis), writes through
-the ``TensorClass`` go directly to that backend. When a ``TypedTensorDict`` is
-constructed from persistent data, the data is copied into memory.
+Both wrappers keep the backend alive -- mutations through the typed wrapper go
+directly to the underlying storage. Direct construction (without
+``from_tensordict``) creates an in-memory ``TensorDict`` as the backend.
 
-Choose ``TensorClass`` when you need live access to a remote or on-disk backend
-with typed field access. Choose ``TypedTensorDict`` when you want typed,
-in-memory state with ``**state`` spreading and standard Python inheritance.
+Choose ``TensorClass`` when you need non-tensor fields or want to integrate
+with existing tensorclass-based APIs. Choose ``TypedTensorDict`` when you
+want native ``**state`` spreading, standard Python inheritance for schema
+composition, and full ``TensorDictBase`` API compatibility.
