@@ -1361,27 +1361,62 @@ def istensor(cls):
13611361 return True
13621362
13631363
1364+ def _plain_summary (tensor : Tensor ) -> str :
1365+ """Short value summary for ``plain`` print mode."""
1366+ if isinstance (tensor , UninitializedTensorMixin ) or tensor .numel () == 0 :
1367+ return "[]"
1368+ if tensor .is_floating_point () or tensor .is_complex ():
1369+ return f"mean={ tensor .to (torch .float64 ).mean ().item ():.4g} , std={ tensor .to (torch .float64 ).std ().item ():.4g} "
1370+ return f"min={ tensor .min ().item ()} , max={ tensor .max ().item ()} "
1371+
1372+
1373+ def _tensor_repr_fields (tensor : Tensor ) -> list [str ]:
1374+ """Build the list of ``key=value`` strings for a single tensor descriptor."""
1375+ opts = _REPR_OPTIONS
1376+ parts : list [str ] = []
1377+ if opts ["show_shape" ]:
1378+ parts .append (f"shape={ _shape (tensor )} " )
1379+ if opts ["show_tensor_device" ]:
1380+ parts .append (f"device={ _device (tensor )} " )
1381+ if opts ["show_tensor_dtype" ]:
1382+ parts .append (f"dtype={ _dtype (tensor )} " )
1383+ if opts ["show_tensor_is_shared" ]:
1384+ parts .append (f"is_shared={ _is_shared (tensor )} " )
1385+ if opts ["show_grad" ]:
1386+ parts .append (f"requires_grad={ tensor .requires_grad } " )
1387+ if opts ["show_is_contiguous" ]:
1388+ parts .append (f"is_contiguous={ tensor .is_contiguous ()} " )
1389+ if opts ["show_is_view" ]:
1390+ parts .append (f"is_view={ tensor ._base is not None } " )
1391+ if opts ["show_storage_size" ]:
1392+ parts .append (f"storage_size={ tensor .untyped_storage ().nbytes ()} " )
1393+ if opts ["plain" ]:
1394+ parts .append (_plain_summary (tensor ))
1395+ return parts
1396+
1397+
13641398def _get_repr (tensor : Tensor ) -> str :
1365- s = ", " .join (
1366- [
1367- f"shape={ _shape (tensor )} " ,
1368- f"device={ _device (tensor )} " ,
1369- f"dtype={ _dtype (tensor )} " ,
1370- f"is_shared={ _is_shared (tensor )} " ,
1371- ]
1372- )
1399+ s = ", " .join (_tensor_repr_fields (tensor ))
13731400 return f"{ type (tensor ).__name__ } ({ s } )"
13741401
13751402
1403+ def _tensor_repr_fields_custom (shape , device , dtype , is_shared ) -> list [str ]:
1404+ """Build the list of ``key=value`` strings for a tensor descriptor with explicit values."""
1405+ opts = _REPR_OPTIONS
1406+ parts : list [str ] = []
1407+ if opts ["show_shape" ]:
1408+ parts .append (f"shape={ shape } " )
1409+ if opts ["show_tensor_device" ]:
1410+ parts .append (f"device={ device } " )
1411+ if opts ["show_tensor_dtype" ]:
1412+ parts .append (f"dtype={ dtype } " )
1413+ if opts ["show_tensor_is_shared" ]:
1414+ parts .append (f"is_shared={ is_shared } " )
1415+ return parts
1416+
1417+
13761418def _get_repr_custom (cls , shape , device , dtype , is_shared ) -> str :
1377- s = ", " .join (
1378- [
1379- f"shape={ shape } " ,
1380- f"device={ device } " ,
1381- f"dtype={ dtype } " ,
1382- f"is_shared={ is_shared } " ,
1383- ]
1384- )
1419+ s = ", " .join (_tensor_repr_fields_custom (shape , device , dtype , is_shared ))
13851420 return f"{ cls .__name__ } ({ s } )"
13861421
13871422
@@ -1766,6 +1801,107 @@ def _getitem_batch_size(batch_size, index):
17661801 return torch .Size (out )
17671802
17681803
1804+ # Repr / print options
1805+ _REPR_OPTIONS = {
1806+ "show_batch_size" : True ,
1807+ "show_device" : True ,
1808+ "show_is_shared" : True ,
1809+ "show_shape" : True ,
1810+ "show_tensor_device" : True ,
1811+ "show_tensor_dtype" : True ,
1812+ "show_tensor_is_shared" : True ,
1813+ "show_grad" : False ,
1814+ "show_is_contiguous" : False ,
1815+ "show_is_view" : False ,
1816+ "show_storage_size" : False ,
1817+ "plain" : False ,
1818+ }
1819+
1820+ _REPR_OPTIONS_KEYS = frozenset (_REPR_OPTIONS )
1821+
1822+
1823+ class set_printoptions (_DecoratorContextManager ):
1824+ """Controls which attributes appear in TensorDict's ``__repr__`` output.
1825+
1826+ Can be used as a global setter (via :meth:`set`), a context manager, or a
1827+ decorator. Follows the same pattern as :class:`set_lazy_legacy`.
1828+
1829+ Keyword Args:
1830+ show_batch_size (bool, optional): Show ``batch_size`` in TensorDict repr.
1831+ Defaults to ``True``.
1832+ show_device (bool, optional): Show ``device`` in TensorDict repr.
1833+ Defaults to ``True``.
1834+ show_is_shared (bool, optional): Show ``is_shared`` in TensorDict repr.
1835+ Defaults to ``True``.
1836+ show_shape (bool, optional): Show ``shape`` in per-tensor field descriptors.
1837+ Defaults to ``True``.
1838+ show_tensor_device (bool, optional): Show ``device`` in per-tensor field
1839+ descriptors. Defaults to ``True``.
1840+ show_tensor_dtype (bool, optional): Show ``dtype`` in per-tensor field
1841+ descriptors. Defaults to ``True``.
1842+ show_tensor_is_shared (bool, optional): Show ``is_shared`` in per-tensor
1843+ field descriptors. Defaults to ``True``.
1844+ show_grad (bool, optional): Show ``requires_grad`` in per-tensor field
1845+ descriptors. Defaults to ``False``.
1846+ show_is_contiguous (bool, optional): Show ``is_contiguous`` in per-tensor
1847+ field descriptors. Defaults to ``False``.
1848+ show_is_view (bool, optional): Show ``is_view`` in per-tensor field
1849+ descriptors. Defaults to ``False``.
1850+ show_storage_size (bool, optional): Show ``storage_size`` (in bytes) in
1851+ per-tensor field descriptors. Defaults to ``False``.
1852+ plain (bool, optional): When ``True``, include a short summary of the
1853+ actual tensor values in the field descriptors. Defaults to ``False``.
1854+
1855+ Examples:
1856+ >>> import torch
1857+ >>> from tensordict import TensorDict, set_printoptions
1858+ >>> td = TensorDict({"x": torch.randn(3, 4)})
1859+ >>> # Global
1860+ >>> set_printoptions(show_device=False, show_is_shared=False).set()
1861+ >>> print(td)
1862+ >>> # Context manager
1863+ >>> with set_printoptions(show_tensor_dtype=False):
1864+ ... print(td)
1865+ >>> # Decorator
1866+ >>> @set_printoptions(show_is_shared=False)
1867+ ... def my_func(td):
1868+ ... print(td)
1869+
1870+ """
1871+
1872+ def __init__ (self , ** kwargs : bool ) -> None :
1873+ super ().__init__ ()
1874+ unknown = set (kwargs ) - _REPR_OPTIONS_KEYS
1875+ if unknown :
1876+ raise TypeError (
1877+ f"Unknown printoptions: { unknown } . Valid options: { sorted (_REPR_OPTIONS_KEYS )} "
1878+ )
1879+ self ._kwargs = kwargs
1880+
1881+ def clone (self ) -> set_printoptions :
1882+ return type (self )(** self ._kwargs )
1883+
1884+ def __enter__ (self ) -> None :
1885+ self .set ()
1886+
1887+ def set (self ) -> None :
1888+ global _REPR_OPTIONS
1889+ self ._old = dict (_REPR_OPTIONS )
1890+ _REPR_OPTIONS .update (self ._kwargs )
1891+
1892+ def __exit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
1893+ global _REPR_OPTIONS
1894+ _REPR_OPTIONS .update (self ._old )
1895+
1896+
1897+ def get_printoptions () -> dict :
1898+ """Returns the current TensorDict print options as a dict.
1899+
1900+ See :class:`set_printoptions` for details on each option.
1901+ """
1902+ return dict (_REPR_OPTIONS )
1903+
1904+
17691905# Lazy classes control (legacy feature)
17701906_DEFAULT_LAZY_OP = False
17711907_LAZY_OP = os .environ .get ("LAZY_LEGACY_OP" )
0 commit comments