Commit 9a3c99c

[Feature] Add set_printoptions for configurable TensorDict repr

Add set_printoptions / get_printoptions to control which attributes appear in
TensorDict's __repr__. Works as a global setter, context manager, or decorator.

Options:
- show/hide batch_size, device, is_shared at the TensorDict level;
- show/hide shape, device, dtype, is_shared at the per-tensor level;
- opt-in extended attributes (requires_grad, is_contiguous, is_view,
  storage_size, plain value summary).

Also adds a "Printing and Display" documentation page covering the feature, the
motivation behind TensorDict's metadata-first repr, and parse_tensor_dict_string.

Made-with: Cursor

1 parent e214613 commit 9a3c99c

File tree

6 files changed: +183 −6 lines changed

- docs/source/index.rst
- docs/source/printing.rst
- docs/source/reference/td.rst
- tensordict/_lazy.py
- tensordict/base.py
- test/test_tensordict.py

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -108,6 +108,7 @@ Contents
     distributed
     fx
     saving
+    printing
     reference/index

 Indices and tables

docs/source/printing.rst

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@ (new file; content shown below)
Printing and Display
====================

Why printing a TensorDict is more useful than printing a tensor
---------------------------------------------------------------

When working with PyTorch tensors, calling ``print(tensor)`` dumps the raw
numerical content. In practice, however, most ``print`` calls during debugging
are motivated by a single question: *what does this tensor look like?* You want
its shape, dtype, device, maybe whether it requires a gradient -- not a wall of
floating-point numbers.

Because a :class:`~tensordict.TensorDict` groups multiple tensors under named
keys, its ``__repr__`` gives you exactly that -- a structured, at-a-glance
summary of every tensor it contains:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     image=torch.randn(32, 3, 64, 64),
...     label=torch.randint(10, (32,)),
...     batch_size=[32],
... )
>>> print(td)
TensorDict(
    fields={
        image: Tensor(shape=torch.Size([32, 3, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        label: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

No data is printed, no truncation ellipses, no guessing at dimensionality.
One glance tells you the names, shapes, dtypes and devices of everything in
the batch.

Configuring the display with ``set_printoptions``
-------------------------------------------------

By default, every attribute is shown for backward compatibility. In many
situations, though, some of those attributes are noise. For instance, if all
your work is on CPU and nothing is shared, ``device=cpu`` and
``is_shared=False`` are repeated on every line without adding information.

:class:`~tensordict.set_printoptions` lets you control exactly which attributes
appear. It works as a **global setter**, a **context manager** or a
**decorator**, following the same pattern as :class:`~tensordict.set_lazy_legacy`
and :func:`torch.set_printoptions`.
Global configuration
~~~~~~~~~~~~~~~~~~~~

Call :meth:`~tensordict.set_printoptions.set` to change the defaults for the
rest of the process:

>>> from tensordict import set_printoptions
>>> set_printoptions(show_device=False, show_is_shared=False, show_tensor_device=False, show_tensor_is_shared=False).set()
>>> print(td)
TensorDict(
    fields={
        image: Tensor(shape=torch.Size([32, 3, 64, 64]), dtype=torch.float32),
        label: Tensor(shape=torch.Size([32]), dtype=torch.int64)},
    batch_size=torch.Size([32]))

Scoped configuration (context manager)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Use the context-manager form when you only want the change for a specific
block of code. The previous settings are automatically restored on exit:

>>> from tensordict import set_printoptions
>>> with set_printoptions(show_tensor_dtype=False, show_is_shared=False):
...     print(td)  # dtype and is_shared hidden
>>> print(td)  # back to defaults

Decorator
~~~~~~~~~

You can also decorate a function so that every ``repr`` call inside it uses
the specified options:

>>> @set_printoptions(show_is_shared=False)
... def summarise(td):
...     print(td)

Available options
~~~~~~~~~~~~~~~~~

**TensorDict-level** (these control lines in the outer ``TensorDict(...)``
block):

==================== =========== ===================================================
Option               Default     Description
==================== =========== ===================================================
``show_batch_size``  ``True``    Show the ``batch_size=`` line.
``show_device``      ``True``    Show the ``device=`` line.
``show_is_shared``   ``True``    Show the ``is_shared=`` line.
==================== =========== ===================================================

**Tensor-level** (these control what appears inside each
``Tensor(...)`` field descriptor):

========================== =========== ============================================
Option                     Default     Description
========================== =========== ============================================
``show_shape``             ``True``    Show the ``shape=`` attribute.
``show_tensor_device``     ``True``    Show the ``device=`` attribute.
``show_tensor_dtype``      ``True``    Show the ``dtype=`` attribute.
``show_tensor_is_shared``  ``True``    Show the ``is_shared=`` attribute.
========================== =========== ============================================

**Extended attributes** (off by default -- opt-in for deeper debugging):

====================== =========== =================================================
Option                 Default     Description
====================== =========== =================================================
``show_grad``          ``False``   Show ``requires_grad=``.
``show_is_contiguous`` ``False``   Show ``is_contiguous=``.
``show_is_view``       ``False``   Show ``is_view=`` (whether ``._base`` is set).
``show_storage_size``  ``False``   Show ``storage_size=`` (bytes).
``plain``              ``False``   Append a short value summary
                                   (mean/std for floats, min/max for ints).
====================== =========== =================================================
Querying the current settings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:func:`~tensordict.get_printoptions` returns a dict with the current values:

>>> from tensordict import get_printoptions
>>> get_printoptions()
{'show_batch_size': True, 'show_device': True, ...}


Reconstructing a TensorDict from its printed representation
-----------------------------------------------------------

When debugging, it is common to receive a TensorDict repr as a string -- for
example, pasted from a log file or a colleague's terminal.
:func:`~tensordict.parse_tensor_dict_string` can reconstruct a dummy
:class:`~tensordict.TensorDict` from that string. The resulting object has the
correct structure, batch size, device and dtypes, but all tensor values are
replaced by zeros (since the repr does not contain actual data):

>>> from tensordict import parse_tensor_dict_string
>>> s = """TensorDict(
...     fields={
...         image: Tensor(shape=torch.Size([32, 3, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
...         label: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False)},
...     batch_size=torch.Size([32]),
...     device=cpu,
...     is_shared=False)"""
>>> td = parse_tensor_dict_string(s)
>>> td.batch_size
torch.Size([32])
>>> td["image"].shape
torch.Size([32, 3, 64, 64])

.. note::

   :func:`~tensordict.parse_tensor_dict_string` currently only works with the
   default (plain) print format -- the one that includes ``shape=``,
   ``device=``, ``dtype=`` and ``is_shared=`` for every field.
   If attributes have been hidden via :class:`~tensordict.set_printoptions`,
   the regex parser will not find the expected fields and reconstruction will
   fail. Support for non-default formats will be added in a follow-up PR.
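The kind of regex-based extraction the note refers to can be pictured with a minimal sketch (``parse_fields`` is a hypothetical helper that pulls only name, shape and dtype out of each field line -- not the actual tensordict parser):

```python
import re

# Matches one field line of the default repr format, e.g.
#   image: Tensor(shape=torch.Size([32, 3, 64, 64]), device=cpu, dtype=torch.float32, ...
FIELD_RE = re.compile(
    r"(?P<name>\w+): Tensor\(shape=torch\.Size\(\[(?P<shape>[\d, ]*)\]\), "
    r"device=(?P<device>\w+), dtype=torch\.(?P<dtype>\w+)"
)


def parse_fields(repr_string):
    """Extract {name: {shape, dtype, device}} from a TensorDict repr string."""
    fields = {}
    for m in FIELD_RE.finditer(repr_string):
        shape = tuple(int(d) for d in m.group("shape").split(",") if d.strip())
        fields[m.group("name")] = {
            "shape": shape,
            "dtype": m.group("dtype"),
            "device": m.group("device"),
        }
    return fields
```

Because the pattern hard-codes the ``shape=``, ``device=`` and ``dtype=`` attributes in order, hiding any of them via ``set_printoptions`` makes the match fail -- which is exactly the limitation the note describes.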

docs/source/reference/td.rst

Lines changed: 2 additions & 0 deletions

@@ -288,4 +288,6 @@ Utils
     set_capture_non_tensor_stack
     set_lazy_legacy
     set_list_to_stack
+    set_printoptions
+    get_printoptions
     list_to_stack

tensordict/_lazy.py

Lines changed: 1 addition & 3 deletions

@@ -3515,9 +3515,7 @@ def __repr__(self):
         fields = _td_fields(self)
         parts = [
             indent(f"fields={{{fields}}}", 4 * " "),
-            indent(
-                f"exclusive_fields={{{self._repr_exclusive_fields()}}}", 4 * " "
-            ),
+            indent(f"exclusive_fields={{{self._repr_exclusive_fields()}}}", 4 * " "),
         ]
         if _REPR_OPTIONS["show_batch_size"]:
             parts.append(indent(f"batch_size={self.batch_size}", 4 * " "))
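The gating visible in this hunk -- collecting indented lines into ``parts`` and appending each one only when its option is on -- can be sketched in isolation (a simplified stand-in with its own ``_REPR_OPTIONS`` dict, not the real class):

```python
from textwrap import indent

# Stand-in for tensordict's internal option store.
_REPR_OPTIONS = {"show_batch_size": True, "show_device": False}


def build_repr(name, fields, batch_size, device):
    """Assemble a repr string, skipping lines whose option is off."""
    parts = [indent(f"fields={{{fields}}}", 4 * " ")]
    if _REPR_OPTIONS["show_batch_size"]:
        parts.append(indent(f"batch_size={batch_size}", 4 * " "))
    if _REPR_OPTIONS["show_device"]:
        parts.append(indent(f"device={device}", 4 * " "))
    body = ",\n".join(parts)
    return f"{name}(\n{body})"
```

With ``show_device`` off, the assembled string contains a ``batch_size=`` line but no ``device=`` line, which is how hidden attributes simply disappear from the output rather than being blanked.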

tensordict/base.py

Lines changed: 1 addition & 1 deletion

@@ -81,8 +81,8 @@
     _prefix_last_key,
     _proc_init,
     _prune_selected_keys,
-    _REPR_OPTIONS,
     _rebuild_njt_from_njt,
+    _REPR_OPTIONS,
     _set_max_batch_size,
     _shape,
     _split_tensordict,

test/test_tensordict.py

Lines changed: 12 additions & 2 deletions

@@ -10194,7 +10194,12 @@ def test_hide_multiple(self):
         from tensordict import set_printoptions

         td = TensorDict({"a": torch.randn(3, 4)})
-        with set_printoptions(show_device=False, show_is_shared=False, show_tensor_dtype=False, show_tensor_is_shared=False):
+        with set_printoptions(
+            show_device=False,
+            show_is_shared=False,
+            show_tensor_dtype=False,
+            show_tensor_is_shared=False,
+        ):
             r = repr(td)
         assert "\n    device=" not in r
         assert "\n    is_shared=" not in r

@@ -10286,7 +10291,12 @@ class MyClass:
             x: torch.Tensor

         obj = MyClass(x=torch.randn(3, 4), batch_size=[3])
-        with set_printoptions(show_device=False, show_is_shared=False, show_tensor_device=False, show_tensor_is_shared=False):
+        with set_printoptions(
+            show_device=False,
+            show_is_shared=False,
+            show_tensor_device=False,
+            show_tensor_is_shared=False,
+        ):
             r = repr(obj)
         assert "MyClass(" in r
         assert "device=" not in r
