Skip to content

Commit 3a0033a

Browse files
committed
Split utils option helpers
1 parent 8c46f0c commit 3a0033a

3 files changed

Lines changed: 295 additions & 365 deletions

File tree

tensordict/_utils_options.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import os
9+
from typing import Any
10+
11+
from torch.utils._contextlib import _DecoratorContextManager
12+
13+
__all__ = [
14+
"_REPR_OPTIONS",
15+
"_legacy_lazy",
16+
"capture_non_tensor_stack",
17+
"get_printoptions",
18+
"lazy_legacy",
19+
"list_to_stack",
20+
"set_capture_non_tensor_stack",
21+
"set_lazy_legacy",
22+
"set_list_to_stack",
23+
"set_printoptions",
24+
]
25+
26+
27+
def _strtobool(val):
28+
val = val.lower()
29+
if val in ("y", "yes", "t", "true", "on", "1"):
30+
return 1
31+
if val in ("n", "no", "f", "false", "off", "0"):
32+
return 0
33+
raise ValueError(f"invalid truth value {val!r}")
34+
35+
36+
_REPR_OPTIONS = {
37+
"show_batch_size": True,
38+
"show_device": True,
39+
"show_is_shared": True,
40+
"show_shape": True,
41+
"show_field_device": True,
42+
"show_dtype": True,
43+
"show_field_is_shared": True,
44+
"show_grad": False,
45+
"show_is_contiguous": False,
46+
"show_is_view": False,
47+
"show_storage_size": False,
48+
"plain": False,
49+
"sort_keys": "alphabetical",
50+
}
51+
52+
_REPR_OPTIONS_KEYS = frozenset(_REPR_OPTIONS)
53+
54+
_VERBOSE_FALSE_OVERRIDES = {
55+
"show_device": False,
56+
"show_is_shared": False,
57+
"show_field_device": False,
58+
"show_dtype": False,
59+
"show_field_is_shared": False,
60+
}
61+
62+
63+
class set_printoptions(_DecoratorContextManager):
64+
"""Controls which attributes appear in TensorDict's ``__repr__`` output."""
65+
66+
def __init__(self, *, verbose: bool = True, **kwargs) -> None:
67+
super().__init__()
68+
unknown = set(kwargs) - _REPR_OPTIONS_KEYS
69+
if unknown:
70+
raise TypeError(
71+
f"Unknown printoptions: {unknown}. Valid options: {sorted(_REPR_OPTIONS_KEYS)}"
72+
)
73+
if not verbose:
74+
merged = dict(_VERBOSE_FALSE_OVERRIDES)
75+
merged.update(kwargs)
76+
kwargs = merged
77+
self._kwargs = kwargs
78+
79+
def clone(self) -> set_printoptions:
80+
return type(self)(**self._kwargs)
81+
82+
def __enter__(self) -> None:
83+
self.set()
84+
85+
def set(self) -> None:
86+
self._old = dict(_REPR_OPTIONS)
87+
_REPR_OPTIONS.update(self._kwargs)
88+
89+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
90+
_REPR_OPTIONS.update(self._old)
91+
92+
93+
def get_printoptions() -> dict:
94+
"""Returns the current TensorDict print options as a dict."""
95+
return dict(_REPR_OPTIONS)
96+
97+
98+
_DEFAULT_LAZY_OP = False
99+
_LAZY_OP = os.environ.get("LAZY_LEGACY_OP")
100+
101+
102+
class set_lazy_legacy(_DecoratorContextManager):
103+
"""Sets the behaviour of some methods to a lazy transform."""
104+
105+
def __init__(self, mode: bool) -> None:
106+
super().__init__()
107+
self.mode = mode
108+
109+
def clone(self) -> set_lazy_legacy:
110+
return type(self)(self.mode)
111+
112+
def __enter__(self) -> None:
113+
self.set()
114+
115+
def set(self) -> None:
116+
global _LAZY_OP
117+
self._old_mode = _LAZY_OP
118+
_LAZY_OP = bool(self.mode)
119+
os.environ["LAZY_LEGACY_OP"] = str(_LAZY_OP)
120+
121+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
122+
global _LAZY_OP
123+
_LAZY_OP = self._old_mode
124+
os.environ["LAZY_LEGACY_OP"] = str(_LAZY_OP)
125+
126+
127+
def lazy_legacy(allow_none=False):
128+
"""Returns `True` if lazy representations will be used for selected methods."""
129+
if _LAZY_OP is None and allow_none:
130+
return None
131+
if _LAZY_OP is None:
132+
return _DEFAULT_LAZY_OP
133+
return _strtobool(_LAZY_OP) if isinstance(_LAZY_OP, str) else _LAZY_OP
134+
135+
136+
def _legacy_lazy(func):
137+
if not func.__name__.startswith("_legacy_"):
138+
raise NameError(
139+
f"The function name {func.__name__} must start with _legacy_ if it's decorated with _legacy_lazy."
140+
)
141+
func.LEGACY = True
142+
return func
143+
144+
145+
_DEFAULT_CAPTURE_NONTENSOR_STACK = False
146+
_CAPTURE_NONTENSOR_STACK = os.environ.get("CAPTURE_NONTENSOR_STACK")
147+
148+
149+
class set_capture_non_tensor_stack(_DecoratorContextManager):
150+
"""Controls whether identical non-tensor data should be captured when stacked."""
151+
152+
def __init__(self, mode: bool) -> None:
153+
super().__init__()
154+
self.mode = mode
155+
156+
def clone(self) -> set_capture_non_tensor_stack:
157+
return type(self)(self.mode)
158+
159+
def __enter__(self) -> None:
160+
self.set()
161+
162+
def set(self) -> None:
163+
global _CAPTURE_NONTENSOR_STACK
164+
self._old_mode = _CAPTURE_NONTENSOR_STACK
165+
_CAPTURE_NONTENSOR_STACK = bool(self.mode)
166+
os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)
167+
168+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
169+
global _CAPTURE_NONTENSOR_STACK
170+
_CAPTURE_NONTENSOR_STACK = self._old_mode
171+
os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)
172+
173+
174+
def capture_non_tensor_stack(allow_none=False):
175+
"""Get the current setting for capturing non-tensor stacks."""
176+
if _CAPTURE_NONTENSOR_STACK is None and allow_none:
177+
return None
178+
if _CAPTURE_NONTENSOR_STACK is None:
179+
return _DEFAULT_CAPTURE_NONTENSOR_STACK
180+
if (
181+
isinstance(_CAPTURE_NONTENSOR_STACK, str)
182+
and _CAPTURE_NONTENSOR_STACK.lower() == "none"
183+
):
184+
return _DEFAULT_CAPTURE_NONTENSOR_STACK
185+
return (
186+
_strtobool(_CAPTURE_NONTENSOR_STACK)
187+
if isinstance(_CAPTURE_NONTENSOR_STACK, str)
188+
else _CAPTURE_NONTENSOR_STACK
189+
)
190+
191+
192+
_DEFAULT_LIST_TO_STACK = "1"
193+
_LIST_TO_STACK = os.environ.get("LIST_TO_STACK")
194+
195+
196+
class set_list_to_stack(_DecoratorContextManager):
197+
"""Context manager and decorator to control list handling in TensorDict."""
198+
199+
def __init__(self, mode: bool) -> None:
200+
super().__init__()
201+
self.mode = mode
202+
203+
def clone(self) -> set_list_to_stack:
204+
return type(self)(self.mode)
205+
206+
def __enter__(self) -> None:
207+
self.set()
208+
209+
def set(self) -> None:
210+
global _LIST_TO_STACK
211+
self._old_mode = _LIST_TO_STACK
212+
_LIST_TO_STACK = bool(self.mode)
213+
os.environ["LIST_TO_STACK"] = str(_LIST_TO_STACK)
214+
215+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
216+
global _LIST_TO_STACK
217+
_LIST_TO_STACK = self._old_mode
218+
os.environ["LIST_TO_STACK"] = str(_LIST_TO_STACK)
219+
220+
221+
def list_to_stack(allow_none=False):
222+
"""Retrieves the current setting for list-to-stack conversion in TensorDict."""
223+
if _LIST_TO_STACK is None and allow_none:
224+
return None
225+
if _LIST_TO_STACK is None:
226+
return _DEFAULT_LIST_TO_STACK
227+
if isinstance(_LIST_TO_STACK, str) and _LIST_TO_STACK.lower() == "none":
228+
return _DEFAULT_LIST_TO_STACK
229+
return (
230+
_strtobool(_LIST_TO_STACK)
231+
if isinstance(_LIST_TO_STACK, str)
232+
else _LIST_TO_STACK
233+
)
234+
235+
236+
for _name in __all__:
237+
_obj = globals()[_name]
238+
if hasattr(_obj, "__module__"):
239+
_obj.__module__ = "tensordict.utils"

0 commit comments

Comments
 (0)