Skip to content

Commit e214613

Browse files
committed
init
1 parent 9b15864 commit e214613

File tree

6 files changed

+403
-57
lines changed

6 files changed

+403
-57
lines changed

tensordict/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
assert_allclose_td,
6161
assert_close,
6262
capture_non_tensor_stack,
63+
get_printoptions,
6364
is_batchedtensor,
6465
is_non_tensor,
6566
is_tensorclass,
@@ -69,6 +70,7 @@
6970
set_capture_non_tensor_stack,
7071
set_lazy_legacy,
7172
set_list_to_stack,
73+
set_printoptions,
7274
unravel_key,
7375
unravel_key_list,
7476
)
@@ -159,6 +161,8 @@
159161
"set_lazy_legacy",
160162
"list_to_stack",
161163
"set_list_to_stack",
164+
"set_printoptions",
165+
"get_printoptions",
162166
# TensorClass components
163167
"tensorclass",
164168
"MetaData",

tensordict/_lazy.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
_parse_to,
6767
_recursive_unbind_list,
6868
_renamed_inplace_method,
69+
_REPR_OPTIONS,
6970
_shape,
7071
_td_fields,
7172
_unravel_key_to_tuple,
@@ -3512,24 +3513,20 @@ def _propagate_unlock(self):
35123513

35133514
def __repr__(self):
    """Multi-line summary of the stacked tensordict.

    The ``batch_size``/``device``/``is_shared`` rows honor the global print
    options (see ``set_printoptions``); ``stack_dim`` is always shown.
    """
    pad = 4 * " "
    rows = [
        indent(f"fields={{{_td_fields(self)}}}", pad),
        indent(
            f"exclusive_fields={{{self._repr_exclusive_fields()}}}", pad
        ),
    ]
    # Optional metadata rows, gated on the global repr options.
    if _REPR_OPTIONS["show_batch_size"]:
        rows.append(indent(f"batch_size={self.batch_size}", pad))
    if _REPR_OPTIONS["show_device"]:
        rows.append(indent(f"device={self.device}", pad))
    if _REPR_OPTIONS["show_is_shared"]:
        rows.append(indent(f"is_shared={self.is_shared()}", pad))
    rows.append(indent(f"stack_dim={self.stack_dim}", pad))
    body = ",\n".join(rows)
    return f"{type(self).__name__}(\n{body})"
35343531

35353532
def _exclusive_keys(self):

tensordict/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
_prefix_last_key,
8282
_proc_init,
8383
_prune_selected_keys,
84+
_REPR_OPTIONS,
8485
_rebuild_njt_from_njt,
8586
_set_max_batch_size,
8687
_shape,
@@ -553,11 +554,14 @@ def __lt__(self, other: object) -> Self:
553554
def __repr__(self) -> str:
554555
try:
555556
fields = _td_fields(self)
556-
field_str = indent(f"fields={{{fields}}}", 4 * " ")
557-
batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
558-
device_str = indent(f"device={self.device}", 4 * " ")
559-
is_shared_str = indent(f"is_shared={self.is_shared()}", 4 * " ")
560-
string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str])
557+
parts = [indent(f"fields={{{fields}}}", 4 * " ")]
558+
if _REPR_OPTIONS["show_batch_size"]:
559+
parts.append(indent(f"batch_size={self.batch_size}", 4 * " "))
560+
if _REPR_OPTIONS["show_device"]:
561+
parts.append(indent(f"device={self.device}", 4 * " "))
562+
if _REPR_OPTIONS["show_is_shared"]:
563+
parts.append(indent(f"is_shared={self.is_shared()}", 4 * " "))
564+
string = ",\n".join(parts)
561565
except AttributeError:
562566
# When using torch.compile, an exception may be raised with a tensordict object
563567
# that has no attribute (no _tensordict or no _batch_size).

tensordict/tensorclass.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
_is_json_serializable,
6464
_is_tensorclass,
6565
_LOCK_ERROR,
66+
_REPR_OPTIONS,
6667
_td_fields,
6768
_TENSORCLASS_MEMO,
6869
_unravel_key_to_tuple,
@@ -2274,34 +2275,31 @@ def _repr(self) -> str:
22742275
field_str = [fields] if fields else []
22752276
non_tensor_fields = _all_non_td_fields_as_str(self._non_tensordict)
22762277

2277-
medatada_fields = []
2278+
metadata_fields = []
22782279

2279-
if "batch_size" not in self.__expected_keys__:
2280-
batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
2281-
medatada_fields.append(batch_size_str)
2282-
elif "shape" not in self.__expected_keys__:
2283-
batch_size_str = indent(f"shape={self.shape}", 4 * " ")
2284-
medatada_fields.append(batch_size_str)
2285-
if "device" not in self.__expected_keys__:
2286-
device_str = indent(f"device={self.device}", 4 * " ")
2287-
medatada_fields.append(device_str)
2288-
2289-
is_shared_str = indent(f"is_shared={self.is_shared()}", 4 * " ")
2290-
medatada_fields.append(is_shared_str)
2280+
if _REPR_OPTIONS["show_batch_size"]:
2281+
if "batch_size" not in self.__expected_keys__:
2282+
metadata_fields.append(indent(f"batch_size={self.batch_size}", 4 * " "))
2283+
elif "shape" not in self.__expected_keys__:
2284+
metadata_fields.append(indent(f"shape={self.shape}", 4 * " "))
2285+
if _REPR_OPTIONS["show_device"] and "device" not in self.__expected_keys__:
2286+
metadata_fields.append(indent(f"device={self.device}", 4 * " "))
2287+
if _REPR_OPTIONS["show_is_shared"]:
2288+
metadata_fields.append(indent(f"is_shared={self.is_shared()}", 4 * " "))
22912289

22922290
if len(non_tensor_fields) > 0:
22932291
non_tensor_field_str = indent(
22942292
",\n".join(non_tensor_fields),
22952293
4 * " ",
22962294
)
22972295
if field_str:
2298-
string = ",\n".join(field_str + [non_tensor_field_str, *medatada_fields])
2296+
string = ",\n".join(field_str + [non_tensor_field_str, *metadata_fields])
22992297
else:
2300-
string = ",\n".join([non_tensor_field_str, *medatada_fields])
2298+
string = ",\n".join([non_tensor_field_str, *metadata_fields])
23012299
elif field_str:
2302-
string = ",\n".join(field_str + medatada_fields)
2303-
elif len(medatada_fields) > 0:
2304-
string = ",\n".join(medatada_fields)
2300+
string = ",\n".join(field_str + metadata_fields)
2301+
elif len(metadata_fields) > 0:
2302+
string = ",\n".join(metadata_fields)
23052303
else:
23062304
string = ""
23072305
return f"{type(self).__name__}({string})"

tensordict/utils.py

Lines changed: 152 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,27 +1361,62 @@ def istensor(cls):
13611361
return True
13621362

13631363

1364+
def _plain_summary(tensor: Tensor) -> str:
1365+
"""Short value summary for ``plain`` print mode."""
1366+
if isinstance(tensor, UninitializedTensorMixin) or tensor.numel() == 0:
1367+
return "[]"
1368+
if tensor.is_floating_point() or tensor.is_complex():
1369+
return f"mean={tensor.to(torch.float64).mean().item():.4g}, std={tensor.to(torch.float64).std().item():.4g}"
1370+
return f"min={tensor.min().item()}, max={tensor.max().item()}"
1371+
1372+
1373+
def _tensor_repr_fields(tensor: Tensor) -> list[str]:
    """Build the list of ``key=value`` strings for a single tensor descriptor."""
    opts = _REPR_OPTIONS
    # Each entry pairs an option flag with a lazily-evaluated formatter so that
    # disabled (and potentially expensive, e.g. storage-size) fields are never
    # computed.
    lazy_fields = (
        ("show_shape", lambda: f"shape={_shape(tensor)}"),
        ("show_tensor_device", lambda: f"device={_device(tensor)}"),
        ("show_tensor_dtype", lambda: f"dtype={_dtype(tensor)}"),
        ("show_tensor_is_shared", lambda: f"is_shared={_is_shared(tensor)}"),
        ("show_grad", lambda: f"requires_grad={tensor.requires_grad}"),
        ("show_is_contiguous", lambda: f"is_contiguous={tensor.is_contiguous()}"),
        ("show_is_view", lambda: f"is_view={tensor._base is not None}"),
        (
            "show_storage_size",
            lambda: f"storage_size={tensor.untyped_storage().nbytes()}",
        ),
    )
    parts = [make() for flag, make in lazy_fields if opts[flag]]
    if opts["plain"]:
        parts.append(_plain_summary(tensor))
    return parts
1396+
1397+
13641398
def _get_repr(tensor: Tensor) -> str:
    """Return ``TypeName(field, field, ...)`` for *tensor*, honoring the
    global print options."""
    joined = ", ".join(_tensor_repr_fields(tensor))
    return f"{type(tensor).__name__}({joined})"
13741401

13751402

1403+
def _tensor_repr_fields_custom(shape, device, dtype, is_shared) -> list[str]:
    """Build the list of ``key=value`` strings for a tensor descriptor with explicit values."""
    opts = _REPR_OPTIONS
    # (option flag, label, pre-computed value) — values are already supplied
    # by the caller, so there is nothing lazy to defer here.
    descriptors = (
        ("show_shape", "shape", shape),
        ("show_tensor_device", "device", device),
        ("show_tensor_dtype", "dtype", dtype),
        ("show_tensor_is_shared", "is_shared", is_shared),
    )
    return [f"{label}={value}" for flag, label, value in descriptors if opts[flag]]
1416+
1417+
13761418
def _get_repr_custom(cls, shape, device, dtype, is_shared) -> str:
    """Format a ``ClsName(field, field, ...)`` descriptor from explicitly
    supplied metadata rather than a live tensor."""
    body = ", ".join(_tensor_repr_fields_custom(shape, device, dtype, is_shared))
    return f"{cls.__name__}({body})"
13861421

13871422

@@ -1766,6 +1801,107 @@ def _getitem_batch_size(batch_size, index):
17661801
return torch.Size(out)
17671802

17681803

1804+
# Repr / print options
# Module-level mutable defaults consumed by the repr helpers; mutated in
# place by ``set_printoptions`` and snapshotted by ``get_printoptions``.
_REPR_OPTIONS = {
    # TensorDict-level metadata rows (shown by default).
    "show_batch_size": True,
    "show_device": True,
    "show_is_shared": True,
    # Per-tensor field-descriptor entries (shown by default).
    "show_shape": True,
    "show_tensor_device": True,
    "show_tensor_dtype": True,
    "show_tensor_is_shared": True,
    # Per-tensor field-descriptor entries (opt-in).
    "show_grad": False,
    "show_is_contiguous": False,
    "show_is_view": False,
    "show_storage_size": False,
    # Opt-in short value summary appended to each field descriptor.
    "plain": False,
}

# Frozen view of the valid option names, used to validate kwargs in
# ``set_printoptions.__init__``.
_REPR_OPTIONS_KEYS = frozenset(_REPR_OPTIONS)
1821+
1822+
1823+
class set_printoptions(_DecoratorContextManager):
    """Select which attributes TensorDict ``__repr__`` implementations display.

    Usable three ways, mirroring :class:`set_lazy_legacy`:

    - globally, via :meth:`set`;
    - as a context manager, restoring the previous options on exit;
    - as a decorator, applying the options for the duration of each call.

    Keyword Args:
        show_batch_size (bool, optional): display ``batch_size`` in the
            TensorDict repr. Defaults to ``True``.
        show_device (bool, optional): display ``device`` in the TensorDict
            repr. Defaults to ``True``.
        show_is_shared (bool, optional): display ``is_shared`` in the
            TensorDict repr. Defaults to ``True``.
        show_shape (bool, optional): display ``shape`` in per-tensor field
            descriptors. Defaults to ``True``.
        show_tensor_device (bool, optional): display ``device`` in per-tensor
            field descriptors. Defaults to ``True``.
        show_tensor_dtype (bool, optional): display ``dtype`` in per-tensor
            field descriptors. Defaults to ``True``.
        show_tensor_is_shared (bool, optional): display ``is_shared`` in
            per-tensor field descriptors. Defaults to ``True``.
        show_grad (bool, optional): display ``requires_grad`` in per-tensor
            field descriptors. Defaults to ``False``.
        show_is_contiguous (bool, optional): display ``is_contiguous`` in
            per-tensor field descriptors. Defaults to ``False``.
        show_is_view (bool, optional): display ``is_view`` in per-tensor
            field descriptors. Defaults to ``False``.
        show_storage_size (bool, optional): display ``storage_size`` (in
            bytes) in per-tensor field descriptors. Defaults to ``False``.
        plain (bool, optional): append a short summary of the actual tensor
            values to each field descriptor. Defaults to ``False``.

    Raises:
        TypeError: if an unknown option name is passed.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict, set_printoptions
        >>> td = TensorDict({"x": torch.randn(3, 4)})
        >>> # Global
        >>> set_printoptions(show_device=False, show_is_shared=False).set()
        >>> print(td)
        >>> # Context manager
        >>> with set_printoptions(show_tensor_dtype=False):
        ...     print(td)
        >>> # Decorator
        >>> @set_printoptions(show_is_shared=False)
        ... def my_func(td):
        ...     print(td)

    """

    def __init__(self, **kwargs: bool) -> None:
        super().__init__()
        invalid = set(kwargs).difference(_REPR_OPTIONS_KEYS)
        if invalid:
            raise TypeError(
                f"Unknown printoptions: {invalid}. Valid options: {sorted(_REPR_OPTIONS_KEYS)}"
            )
        self._requested = kwargs

    def clone(self) -> set_printoptions:
        # A fresh instance per use, so decorator invocations do not share
        # the saved-state attribute.
        return type(self)(**self._requested)

    def set(self) -> None:
        """Apply the options globally, remembering the previous values."""
        # Note: _REPR_OPTIONS is mutated in place, never rebound, so no
        # ``global`` declaration is needed.
        self._saved = dict(_REPR_OPTIONS)
        _REPR_OPTIONS.update(self._requested)

    def __enter__(self) -> None:
        self.set()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the snapshot taken by set()/__enter__.
        _REPR_OPTIONS.update(self._saved)
1895+
1896+
1897+
def get_printoptions() -> dict:
    """Returns the current TensorDict print options as a dict.

    See :class:`set_printoptions` for details on each option.
    """
    # Return a shallow copy so callers cannot mutate the live options.
    return {**_REPR_OPTIONS}
1903+
1904+
17691905
# Lazy classes control (legacy feature)
17701906
_DEFAULT_LAZY_OP = False
17711907
_LAZY_OP = os.environ.get("LAZY_LEGACY_OP")

0 commit comments

Comments
 (0)