
Commit a1bc21f

[Doc] Link README features to tutorials, add functional programming tutorial (#1615)
1 parent 2693086 commit a1bc21f

3 files changed: 122 additions & 5 deletions

README.md

Lines changed: 14 additions & 5 deletions
````diff
@@ -1,13 +1,13 @@
 [![Docs](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=tensordict)][#docs-package]
-[![Discord](https://dcbadge.vercel.app/api/server/tz3TgTAe3D)](https://discord.gg/tz3TgTAe3D)
+[![Discord](https://img.shields.io/badge/Discord-blue?logo=discord&logoColor=white)](https://discord.gg/tz3TgTAe3D)
 [![Python version](https://img.shields.io/pypi/pyversions/tensordict.svg)](https://www.python.org/downloads/)
 [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)][#github-license]
 <a href="https://pypi.org/project/tensordict"><img src="https://img.shields.io/pypi/v/tensordict" alt="pypi version"></a>
 [![Downloads](https://static.pepy.tech/personalized-badge/tensordict?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)][#pepy-package]
 [![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/tensordict?logo=anaconda&style=flat&color=orange)][#conda-forge-package]
 
-[#docs-package]: https://pytorch.github.io/tensordict/
-[#docs-package-benchmark]: https://pytorch.github.io/tensordict/dev/bench/
+[#docs-package]: https://docs.pytorch.org/tensordict/stable/
+[#docs-package-benchmark]: https://docs.pytorch.org/tensordict/stable/dev/bench/
 [#github-license]: https://github.com/pytorch/tensordict/blob/main/LICENSE
 [#pepy-package]: https://pepy.tech/project/tensordict
 [#conda-forge-package]: https://anaconda.org/conda-forge/tensordict
@@ -16,8 +16,8 @@
 
 TensorDict is a dictionary-like class that inherits properties from tensors,
 such as indexing, shape operations, casting to device or storage and many more.
-The code-base consists of two main components: [`TensorDict`](https://pytorch.github.io/tensordict/reference/generated/tensordict.TensorDict.html),
-a specialized dictionary for PyTorch tensors, and [`tensorclass`](https://pytorch.github.io/tensordict/reference/generated/tensordict.tensorclass.html),
+The code-base consists of two main components: [`TensorDict`](https://docs.pytorch.org/tensordict/stable/reference/generated/tensordict.TensorDict.html),
+a specialized dictionary for PyTorch tensors, and [`tensorclass`](https://docs.pytorch.org/tensordict/stable/reference/generated/tensordict.tensorclass.html),
 a dataclass for tensors.
 
 ```python
@@ -48,14 +48,23 @@ TensorDict makes your code-bases more _readable_, _compact_, _modular_ and _fast
 It abstracts away tailored operations, dispatching them on the leaves for you.
 
 - **Composability**: `TensorDict` generalizes `torch.Tensor` operations to collections of tensors.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensordict_shapes.html)
 - **Speed**: asynchronous transfer to device, fast node-to-node communication through `consolidate`, compatible with `torch.compile`.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensordict_memory.html)
 - **Shape operations**: indexing, slicing, concatenation, reshaping -- everything you can do with a tensor.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensordict_slicing.html)
 - **Distributed / multiprocessed**: distribute TensorDict instances across workers, devices and machines.
+  [[doc]](https://docs.pytorch.org/tensordict/stable/distributed.html)
 - **Serialization** and memory-mapping for efficient checkpointing.
+  [[doc]](https://docs.pytorch.org/tensordict/stable/saving.html)
 - **Functional programming** and compatibility with `torch.vmap`.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/functional.html)
 - **Nesting**: nest TensorDict instances to create hierarchical structures.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensordict_keys.html)
 - **Lazy preallocation**: preallocate memory without initializing tensors.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensordict_preallocation.html)
 - **`@tensorclass`**: a specialized dataclass for `torch.Tensor`.
+  [[tutorial]](https://docs.pytorch.org/tensordict/stable/tutorials/tensorclass_fashion.html)
 
 ## Examples
 
````

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -85,6 +85,7 @@ tensordict.nn
    :maxdepth: 1
 
    tutorials/tensordict_module
+   tutorials/functional
    tutorials/export
 
 Dataloading
```
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@ (new file, all additions; the new functional-programming tutorial source reads:)

```python
"""
Functional Programming with TensorDict
=======================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_

In this tutorial you will learn how to use :class:`~.TensorDict` for
functional-style programming with :class:`~torch.nn.Module`, including
parameter swapping, model ensembling with :func:`~torch.vmap`, and
functional calls with :func:`~torch.func.functional_call`.
"""

##############################################################################
# TensorDict as a parameter container
# ------------------------------------
#
# :meth:`~.TensorDict.from_module` extracts the parameters of a module into a
# nested :class:`~.TensorDict` whose structure mirrors the module hierarchy.

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
# sphinx_gallery_end_ignore
import torch
import torch.nn as nn

from tensordict import TensorDict

module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
params = TensorDict.from_module(module)
print(params)

##############################################################################
# The resulting :class:`~.TensorDict` holds the same
# :class:`~torch.nn.Parameter` objects as the module. We can manipulate them
# as a batch -- for example, zeroing all parameters at once:

params_zero = params.detach().clone().zero_()
print("All zeros:", (params_zero == 0).all())

##############################################################################
# Swapping parameters with a context manager
# -------------------------------------------
#
# :meth:`~.TensorDict.to_module` temporarily replaces the parameters of a
# module within a context manager. The original parameters are restored on
# exit.

x = torch.randn(5, 3)

with params_zero.to_module(module):
    y_zero = module(x)

print("Output with zeroed params:", y_zero)
assert (y_zero == 0).all()

y_original = module(x)
print("Output with original params:", y_original)
assert not (y_original == 0).all()

##############################################################################
# Model ensembling with ``torch.vmap``
# -------------------------------------
#
# Because :class:`~.TensorDict` supports batching and stacking, we can stack
# multiple parameter configurations and use :func:`~torch.vmap` to run the
# model across all of them in a single vectorized call.

params_ones = params.detach().clone().apply_(lambda t: t.fill_(1.0))
params_stack = torch.stack([params_zero, params_ones, params])

print("Stacked params batch_size:", params_stack.batch_size)


def call(x, td):
    with td.to_module(module):
        return module(x)


x = torch.randn(3, 5, 3)
y = torch.vmap(call)(x, params_stack)
print("Output shape:", y.shape)

assert (y[0] == 0).all()

##############################################################################
# Functional calls with ``torch.func``
# --------------------------------------
#
# :func:`~torch.func.functional_call` works with the state-dict extracted
# by :meth:`~.TensorDict.from_module`. Because ``from_module`` returns a
# :class:`~.TensorDict` with the same structure as a state-dict, we can
# convert it to a regular dict and pass it directly.

from torch.func import functional_call

flat_params = params.flatten_keys(".")
state_dict = dict(flat_params.items())
x = torch.randn(5, 3)
y = functional_call(module, state_dict, x)
print("functional_call output:", y.shape)

##############################################################################
# The combination of :meth:`~.TensorDict.from_module`,
# :meth:`~.TensorDict.to_module`, and :func:`~torch.vmap` makes it
# straightforward to do things like compute per-sample gradients, run
# model ensembles, or implement meta-learning inner loops -- all without
# leaving the standard PyTorch API.
```
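The closing note of the tutorial mentions per-sample gradients; as a minimal sketch of that use case with plain `torch.func` (not part of this commit; `loss_fn` and the tensor shapes are illustrative):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

module = nn.Linear(3, 1)
# Detached tensor copies of the parameters, so grad() treats them as inputs
params = {k: v.detach() for k, v in module.named_parameters()}


def loss_fn(params, x, y):
    # Run the module functionally on a single sample
    pred = functional_call(module, params, (x.unsqueeze(0),))
    return ((pred.squeeze(0) - y) ** 2).sum()


x, y = torch.randn(8, 3), torch.randn(8, 1)
# grad() differentiates w.r.t. the params dict; vmap maps over samples
# while sharing the parameters (in_dims=None)
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
print(per_sample_grads["weight"].shape)  # torch.Size([8, 1, 3])
```

Each entry of `per_sample_grads` gains a leading dimension of size 8, one gradient per sample, instead of the single summed gradient a regular backward pass would produce.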
