@@ -15,7 +15,7 @@ Architecture overview
 TensorCollection
 ├── TensorDictBase
 │   ├── TensorDict (in-memory)
-│   │   └── TypedTensorDict (typed fields, IS-A TensorDict)
+│   ├── TypedTensorDict (typed fields, wraps any TensorDictBase)
 │   ├── PersistentTensorDict (HDF5-backed)
 │   ├── TensorDictStore (Redis / Dragonfly / KeyDB)
 │   └── LazyStackedTensorDict (lazy stack of heterogeneous TDs)
@@ -26,8 +26,11 @@ Two patterns exist for adding typed field declarations:
 
 - **TensorClass** wraps any ``TensorDictBase`` via ``from_tensordict(td)``.
   It delegates all storage to the wrapped object.
-- **TypedTensorDict** *is* a ``TensorDict``. It stores data in-memory and
-  interoperates with other backends through conversion or stacking.
+- **TypedTensorDict** wraps any ``TensorDictBase`` via ``from_tensordict(td)``,
+  similar to ``TensorClass``. Direct construction creates a ``TensorDict``
+  internally. Unlike ``TensorClass``, it inherits from ``TensorDictBase``
+  directly, supports ``**state`` spreading natively, and uses standard
+  Python inheritance for schema composition.
 
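The ``**state`` spreading difference in the list above is plain Python mechanics: ``**`` unpacking works on any object implementing the mapping protocol, so a class with ``MutableMapping`` behaviour spreads into keyword arguments directly, while a non-mapping wrapper must be repacked field by field. A minimal pure-Python sketch of that mechanism (the ``MappingState`` and ``step`` names are illustrative stand-ins, not tensordict API):

```python
from collections.abc import MutableMapping

class MappingState(MutableMapping):
    """Stand-in for a TypedTensorDict-like object: any MutableMapping
    can be spread with ** into keyword arguments."""

    def __init__(self, **fields):
        self._data = dict(fields)

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

def step(x, y):
    return x + y

state = MappingState(x=1, y=2)
result = step(**state)  # spreads natively -- no field-by-field repacking
```

A wrapper class that does not implement the mapping protocol would raise ``TypeError`` at the ``**state`` call site, which is why the HAS-A pattern needs explicit repacking.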
 TensorClass + backends
 ----------------------
@@ -205,109 +208,162 @@ enforce schemas, but they compose without conflict:
 TypedTensorDict + backends
 --------------------------
 
-``TypedTensorDict`` is a ``TensorDict`` subclass. It stores data in-memory
-but interoperates with other backends through conversion or stacking.
+``TypedTensorDict.from_tensordict(td)`` accepts any ``TensorDictBase`` subclass,
+just like ``TensorClass``. The backend is stored live (no copy) -- mutations
+through the ``TypedTensorDict`` go directly to the underlying backend.
+
+.. code-block:: python
+
+    from tensordict import TypedTensorDict
+    from torch import Tensor
+
+    class State(TypedTensorDict):
+        x: Tensor
+        y: Tensor
+
+    state = State.from_tensordict(some_backend)
 
 .. list-table::
    :header-rows: 1
-   :widths: 30 12 12 12 12 12
+   :widths: 22 10 10 10 10 10 10 10 10
 
-   * - Pattern
+   * - Backend
      - Build
      - Read
      - Write
      - Index
+     - Clone
      - Stack
-   * - Direct construction
-     - yes
+     - Iter
+     - Update
+   * - ``TensorDict``
      - yes
      - yes
      - yes
      - yes
-   * - From H5 (materialise then construct)
      - yes
      - yes
      - yes
      - yes
+   * - ``PersistentTensorDict`` (H5)
      - yes
-   * - From Redis (materialise then construct)
      - yes
      - yes
      - yes
      - yes
      - yes
-   * - From lazy stack (materialise then construct)
      - yes
      - yes
+   * - ``TensorDictStore`` (Redis)
      - yes
      - yes
      - yes
-   * - ``torch.stack`` (dense)
      - yes
      - yes
      - yes
      - yes
-     - --
-   * - ``LazyStackedTensorDict`` of TTDs
      - yes
+   * - ``LazyStackedTensorDict``
      - yes
      - yes
      - yes
-     - --
-   * - ``memmap_()``
      - yes
      - yes
-     - set\_()
      - yes
      - yes
-   * - To H5 (``PersistentTensorDict.from_dict``)
      - yes
+   * - ``TensorDict`` (memmap)
      - yes
-     - H5 rules
      - yes
-     - --
-   * - To Redis (``TensorDictStore.from_tensordict``)
+     - set\_()
      - yes
      - yes
      - yes
      - yes
-     - --
+     - update\_()
 
-Constructing TypedTensorDict from other backends
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. note::
 
-Since ``TypedTensorDict`` is an in-memory ``TensorDict``, loading data from a
-remote or persistent backend requires materialising the data first:
+   Memory-mapped TensorDicts are locked after ``memmap_()``. Use
+   ``set_()`` and ``update_()`` for in-place writes instead of attribute
+   assignment or ``update()``.
+
+Building a TypedTensorDict on each backend
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**In-memory TensorDict** -- the default (direct construction creates one
+internally):
 
 .. code-block:: python
 
     >>> import torch
-    >>> from tensordict import TypedTensorDict
+    >>> from tensordict import TensorDict, TypedTensorDict
     >>> from torch import Tensor
     >>>
     >>> class State(TypedTensorDict):
     ...     x: Tensor
     ...     y: Tensor
+    >>>
+    >>> state = State(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
+    >>> state.x.shape
+    torch.Size([4, 3])
 
-**From HDF5**:
+**Wrapping an existing TensorDict** via ``from_tensordict`` (zero-copy):
+
+.. code-block:: python
+
+    >>> td = TensorDict(x=torch.randn(4, 3), y=torch.randn(4, 5), batch_size=[4])
+    >>> state = State.from_tensordict(td)
+    >>> state.x.shape  # reads from td
+    torch.Size([4, 3])
+    >>> state.x = torch.ones(4, 3)  # writes to td
+    >>> (td["x"] == 1).all()
+    True
+
+**HDF5 (PersistentTensorDict)**:
 
 .. code-block:: python
 
     >>> from tensordict import PersistentTensorDict
     >>>
     >>> h5 = PersistentTensorDict.from_h5("data.h5")
-    >>> local = h5.to_tensordict()
-    >>> state = State(x=local["x"], y=local["y"], batch_size=local.batch_size)
+    >>> state = State.from_tensordict(h5)
+    >>> state.x.shape  # reads from HDF5
+    torch.Size([4, 3])
 
-**From a lazy stack**:
+**Redis (TensorDictStore)**:
+
+.. code-block:: python
+
+    >>> from tensordict.store import TensorDictStore
+    >>>
+    >>> store = TensorDictStore.from_tensordict(td, host="localhost")
+    >>> state = State.from_tensordict(store)
+    >>> state.x.shape  # fetched from Redis
+    torch.Size([4, 3])
+
+**Lazy stack**:
 
 .. code-block:: python
 
     >>> from tensordict import lazy_stack
     >>>
-    >>> ls = lazy_stack([td1, td2], dim=0)
-    >>> local = ls.to_tensordict()
-    >>> state = State(x=local["x"], y=local["y"], batch_size=local.batch_size)
+    >>> tds = [TensorDict(x=torch.randn(3), y=torch.randn(5)) for _ in range(4)]
+    >>> ls = lazy_stack(tds, dim=0)
+    >>> state = State.from_tensordict(ls)
+    >>> state[0].x.shape
+    torch.Size([3])
+
+**Memory-mapped TensorDict**:
+
+.. code-block:: python
+
+    >>> td_mmap = td.memmap_("/tmp/my_memmap")
+    >>> state = State.from_tensordict(td_mmap)
+    >>> state.x.shape
+    torch.Size([4, 3])
+    >>> # memmap TDs are locked -- use in-place operations:
+    >>> state.set_("x", torch.ones(4, 3))
 
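The in-place constraint shown for the memmap backend above mirrors how memory-mapped storage behaves in general: the backing file has a fixed layout, so writes must overwrite existing bytes rather than grow or replace the buffer. A stdlib-only sketch with ``mmap`` (an illustrative analogy, not the tensordict memmap implementation):

```python
import mmap
import os
import tempfile

# A fixed-size file plays the role of memmap_() storage: its layout is
# frozen, so all writes must happen in place at existing offsets.
fd, path = tempfile.mkstemp()
os.write(fd, b"\x00" * 16)  # pre-allocate 16 bytes of zeroed storage
os.close(fd)

with open(path, "r+b") as f:
    mm = mmap.mmap(f.fileno(), 0)
    mm[4:8] = b"\x01\x02\x03\x04"  # in-place overwrite: allowed
    # Changing the size would require mm.resize() -- the analogue of
    # adding a new key, which a locked TensorDict forbids.
    mm.flush()
    mm.close()

with open(path, "rb") as f:
    data = f.read()
os.unlink(path)
```

The write lands in the backing file without any reallocation, which is why in-place setters (``set_()``, ``update_()``) are the natural API for locked, disk-backed storage.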
 Stacking TypedTensorDicts
 ^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -334,30 +390,69 @@ Lazy stacking also works. Indexing a ``LazyStackedTensorDict`` of
     >>> isinstance(ls[0], State)
     True
 
-Saving TypedTensorDict to persistent backends
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _compat-redis-prealloc:
+
+Pre-allocating on Redis and filling iteratively
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Since ``TypedTensorDict`` is a ``TensorDict``, it can be saved to HDF5, Redis,
-or memory-mapped storage directly:
+A common pattern for shared replay buffers or distributed data stores is to
+pre-allocate storage on a remote server (Redis / Dragonfly / KeyDB) and fill
+it one sample at a time, without ever loading the full dataset into RAM.
+
+``TensorDictStore.from_schema`` creates keys with known shapes and dtypes
+directly on the server using ``SETRANGE`` (zero-filled by the server; no
+tensor data passes through Python):
 
 .. code-block:: python
 
-    >>> # To HDF5
-    >>> from tensordict import PersistentTensorDict
-    >>> h5 = PersistentTensorDict.from_dict(state, filename="state.h5")
+    >>> import torch
+    >>> from tensordict import TensorDict, TypedTensorDict
+    >>> from tensordict.store import TensorDictStore
+    >>> from torch import Tensor
     >>>
-    >>> # To memmap
-    >>> state.memmap_("/tmp/state_mmap")
+    >>> class Replay(TypedTensorDict):
+    ...     obs: Tensor
+    ...     action: Tensor
+    ...     reward: Tensor
     >>>
-    >>> # To Redis
-    >>> from tensordict.store import TensorDictStore
-    >>> store = TensorDictStore.from_tensordict(state, host="localhost")
+    >>> # Pre-allocate 100k entries directly on Redis -- no RAM used
+    >>> store = TensorDictStore.from_schema(
+    ...     {"obs": ([84, 84, 3], torch.uint8),
+    ...      "action": ([4], torch.float32),
+    ...      "reward": ([], torch.float32)},
+    ...     batch_size=[100_000],
+    ...     host="redis-node",
+    ... )
+    >>>
+    >>> # Wrap with typed access
+    >>> replay = Replay.from_tensordict(store)
+    >>>
+    >>> # Fill iteratively -- each write goes directly to Redis
+    >>> for i, sample in enumerate(data_stream):
+    ...     replay[i] = Replay(
+    ...         obs=sample.obs, action=sample.action, reward=sample.reward,
+    ...         batch_size=[],
+    ...     )
+
+If the store is initially empty (no keys registered yet), use ``check=False``
+to skip the key-presence validation and fill keys on the fly:
+
+.. code-block:: python
+
+    >>> store = TensorDictStore(batch_size=[100_000], host="redis-node")
+    >>> replay = Replay.from_tensordict(store, check=False)
+    >>>
+    >>> # First indexed write auto-creates each key via SETRANGE
+    >>> replay[0] = Replay(obs=obs_0, action=act_0, reward=r_0, batch_size=[])
+    >>> # Subsequent writes fill in the pre-allocated storage
+    >>> replay[1] = Replay(obs=obs_1, action=act_1, reward=r_1, batch_size=[])
 
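The zero-fill pre-allocation described above is standard Redis ``SETRANGE`` semantics: writing at an offset beyond the current end of a (possibly missing) key makes the server pad the gap with zero bytes, which is what allows a large buffer to be reserved without shipping any tensor data from Python. A small pure-Python model of that semantic (one ``bytearray`` per key stands in for a Redis string value; this is not a Redis client):

```python
def setrange(store: dict, key: str, offset: int, value: bytes) -> int:
    """Model Redis SETRANGE: zero-pad the key's buffer up to `offset`,
    then overwrite len(value) bytes there. Returns the new length."""
    buf = store.setdefault(key, bytearray())
    end = offset + len(value)
    if len(buf) < end:
        buf.extend(b"\x00" * (end - len(buf)))  # server-side zero fill
    buf[offset:end] = value
    return len(buf)

store = {}
# "Pre-allocate" a 16-byte buffer by writing a single byte at the last offset:
setrange(store, "obs", 15, b"\x00")
# Later, fill sample slots in place without resizing:
setrange(store, "obs", 4, b"\x01\x02")
```

Because the padding happens on the server side, pre-allocating a 100k-entry buffer costs one round-trip per key rather than one per element.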
 
 TensorClass vs TypedTensorDict
 ------------------------------
 
-Both enforce typed schemas but differ architecturally:
+Both enforce typed schemas and can wrap any ``TensorDictBase`` backend, but
+they differ architecturally:
 
 .. list-table::
    :header-rows: 1
@@ -366,12 +461,12 @@ Both enforce typed schemas but differ architecturally:
    * - Aspect
      - ``TensorClass``
      - ``TypedTensorDict``
-   * - Relationship to ``TensorDict``
-     - Wraps a ``TensorDictBase`` (HAS-A)
-     - Is a ``TensorDict`` (IS-A)
+   * - Relationship to ``TensorDictBase``
+     - Wraps a ``TensorDictBase`` (HAS-A via ``TensorCollection``)
+     - Is a ``TensorDictBase`` (IS-A, delegates to ``_source``)
    * - Can wrap non-TensorDict backends
      - Yes (H5, Redis, lazy stack, etc.)
-     - No (in-memory only; convert first)
+     - Yes (H5, Redis, lazy stack, etc.)
    * - ``**state`` spreading
      - Field-by-field repacking
      - Natively (``MutableMapping``)
@@ -380,15 +475,19 @@ Both enforce typed schemas but differ architecturally:
      - Not supported (tensor-only)
    * - Backend stays live
      - Yes (writes go to original backend)
-     - No (data is in-memory after construction)
+     - Yes (writes go to original backend)
+   * - Python inheritance
+     - Not supported
+     - Supported (standard class hierarchy)
    * - Composable with each other
      - Yes (``TC.from_tensordict(ttd)`` works)
-     - N/A
+     - Yes (``TTD.from_tensordict(tc._tensordict)`` works)
 
-When a ``TensorClass`` wraps a persistent backend (H5, Redis), writes through
-the ``TensorClass`` go directly to that backend. When a ``TypedTensorDict`` is
-constructed from persistent data, the data is copied into memory.
+Both wrappers keep the backend alive -- mutations through the typed wrapper go
+directly to the underlying storage. Direct construction (without
+``from_tensordict``) creates an in-memory ``TensorDict`` as the backend.
 
-Choose ``TensorClass`` when you need live access to a remote or on-disk backend
-with typed field access. Choose ``TypedTensorDict`` when you want typed,
-in-memory state with ``**state`` spreading and standard Python inheritance.
+Choose ``TensorClass`` when you need non-tensor fields or want to integrate
+with existing tensorclass-based APIs. Choose ``TypedTensorDict`` when you
+want native ``**state`` spreading, standard Python inheritance for schema
+composition, and full ``TensorDictBase`` API compatibility.
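The inheritance-based schema composition recommended above rests on ordinary Python class machinery: field annotations accumulate along the MRO, so a subclass's schema is its bases' fields plus its own. A stand-alone sketch with plain classes (the ``schema`` helper and the classes are illustrative, not tensordict API):

```python
# Stand-in classes: schema composition through ordinary inheritance,
# collecting field annotations from base to derived along the MRO.
class Base:
    x: int

class Extended(Base):
    y: float

def schema(cls) -> dict:
    """Merge per-class __annotations__, base classes first, so derived
    declarations extend (and can override) inherited ones."""
    fields = {}
    for klass in reversed(cls.__mro__):
        # __dict__ lookup avoids picking up inherited annotations twice
        fields.update(klass.__dict__.get("__annotations__", {}))
    return fields

merged = schema(Extended)  # contains both the inherited and the new field
```

A HAS-A wrapper without real inheritance has to re-declare or re-register fields instead, which is the trade-off the comparison table records under "Python inheritance".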