"""
Functional Programming with TensorDict
=======================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_

In this tutorial you will learn how to use :class:`~.TensorDict` for
functional-style programming with :class:`~torch.nn.Module`, including
parameter swapping, model ensembling with :func:`~torch.vmap`, and
functional calls with :func:`~torch.func.functional_call`.
"""

##############################################################################
# TensorDict as a parameter container
# ------------------------------------
#
# :meth:`~.TensorDict.from_module` extracts the parameters of a module into a
# nested :class:`~.TensorDict` whose structure mirrors the module hierarchy.

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
# sphinx_gallery_end_ignore
import torch
import torch.nn as nn
from tensordict import TensorDict

# A small MLP: Linear -> ReLU -> Linear.
layers = (nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
module = nn.Sequential(*layers)

# Pull the module's parameters out into a nested TensorDict whose keys
# mirror the submodule names ("0", "2", ...).
params = TensorDict.from_module(module)
print(params)

##############################################################################
# The resulting :class:`~.TensorDict` holds the same
# :class:`~torch.nn.Parameter` objects as the module. We can manipulate them
# as a batch -- for example, zeroing all parameters at once:

# Detach and clone first so the module's own parameters are left untouched,
# then zero every leaf tensor in place.
params_zero = params.detach().clone()
params_zero.zero_()
print("All zeros:", (params_zero == 0).all())

##############################################################################
# Swapping parameters with a context manager
# -------------------------------------------
#
# :meth:`~.TensorDict.to_module` temporarily replaces the parameters of a
# module within a context manager. The original parameters are restored on
# exit.

x = torch.randn(5, 3)

# Inside the ``with`` block the module runs with the zeroed parameters,
# so the output is identically zero.
with params_zero.to_module(module):
    y_zero = module(x)

print("Output with zeroed params:", y_zero)
assert (y_zero == 0).all()

# On exit the original parameters are restored automatically.
y_restored = module(x)
print("Output with original params:", y_restored)
assert (y_restored != 0).any()

##############################################################################
# Model ensembling with ``torch.vmap``
# -------------------------------------
#
# Because :class:`~.TensorDict` supports batching and stacking, we can stack
# multiple parameter configurations and use :func:`~torch.vmap` to run the
# model across all of them in a single vectorized call.

# A third parameter set where every entry is filled with 1.0.
params_ones = params.detach().clone()
params_ones.apply_(lambda t: t.fill_(1.0))

# Stack the three configurations along a new leading batch dimension.
params_stack = torch.stack([params_zero, params_ones, params])

print("Stacked params batch_size:", params_stack.batch_size)


def call(x, td):
    # Swap ``td`` into ``module`` for the duration of the forward pass;
    # the original parameters are restored when the context exits.
    with td.to_module(module):
        out = module(x)
    return out

# One input per parameter configuration: a batch of 3 inputs of shape (5, 3).
x = torch.randn(3, 5, 3)

# vmap maps over dim 0 of both the inputs and the stacked parameters.
y = torch.vmap(call)(x, params_stack)
print("Output shape:", y.shape)

# The first configuration is all zeros, so its output must be zero too.
assert (y[0] == 0).all()

##############################################################################
# Functional calls with ``torch.func``
# --------------------------------------
#
# :func:`~torch.func.functional_call` works with the state-dict extracted
# by :meth:`~.TensorDict.from_module`. Because ``from_module`` returns a
# :class:`~.TensorDict` with the same structure as a state-dict, we can
# convert it to a regular dict and pass it directly.

from torch.func import functional_call

# Flatten nested keys (("0", "weight") -> "0.weight") to match the
# state-dict naming scheme, then materialize a plain dict.
state_dict = dict(params.flatten_keys(".").items())
x = torch.randn(5, 3)
y = functional_call(module, state_dict, x)
print("functional_call output:", y.shape)

##############################################################################
# The combination of :meth:`~.TensorDict.from_module`,
# :meth:`~.TensorDict.to_module`, and :func:`~torch.vmap` makes it
# straightforward to do things like compute per-sample gradients, run
# model ensembles, or implement meta-learning inner loops -- all without
# leaving the standard PyTorch API.