Skip to content

Commit 706f5d3

Browse files
committed
Split utils key and JSON helpers
1 parent 8c46f0c commit 706f5d3

3 files changed

Lines changed: 191 additions & 186 deletions

File tree

tensordict/_utils_key_json.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import importlib.util
9+
import warnings
10+
11+
from pyvers import get_backend, implement_for, register_backend, set_backend
12+
13+
# Every helper defined in this module, private ones included.  Besides
# controlling ``from ... import *``, this list is iterated at the bottom of
# the module to rewrite each listed object's ``__module__`` attribute.
__all__ = [
    "_decode_key_from_filesystem",
    "_encode_key_for_filesystem",
    "_get_robust_key_setting",
    "_get_robust_key_setting_with_warning",
    "_json_dumps",
    "get_json_backend",
    "json_dumps",
    "set_json_backend",
]
23+
24+
25+
def _encode_key_for_filesystem(key: str, *, robust: bool = True) -> str:
26+
"""Encode a TensorDict key to be safe for filesystem paths."""
27+
if not robust:
28+
return key
29+
30+
unsafe_chars = set('/<>:"|?*\\ \0%')
31+
unsafe_chars.update(chr(i) for i in range(32))
32+
unsafe_chars.add(chr(127))
33+
34+
encoded_parts = []
35+
for char in key:
36+
if char in unsafe_chars:
37+
encoded_parts.append(f"%{ord(char):02X}")
38+
else:
39+
encoded_parts.append(char)
40+
41+
return "".join(encoded_parts)
42+
43+
44+
def _get_robust_key_setting_with_warning(key: str, robust_key) -> bool:
    """Resolve ``robust_key`` for ``key``, warning only when it matters.

    Args:
        key: The TensorDict key about to be mapped to a file name.
        robust_key: ``None`` to use the default, otherwise the caller's
            explicit choice, which is returned as is.

    Returns:
        ``robust_key`` when it is not ``None``; otherwise ``False``.  In the
        latter case a ``FutureWarning`` is emitted if the robust encoding of
        ``key`` differs from the legacy one.
    """
    if robust_key is None:
        legacy_name = _encode_key_for_filesystem(key, robust=False)
        robust_name = _encode_key_for_filesystem(key, robust=True)
        if legacy_name != robust_name:
            message = (
                f"The key '{key}' contains characters that will be handled differently "
                f"in TensorDict v0.12 for better cross-platform support. "
                f"To opt into the new behavior now, use `robust_key=True`. "
                f"To suppress this warning and keep the current behavior, use `robust_key=False`. "
                f"See https://github.com/pytorch/tensordict/issues/1440 for details."
            )
            # stacklevel=3 points the warning two frames above this helper.
            warnings.warn(message, FutureWarning, stacklevel=3)
        # Unspecified always resolves to the legacy behaviour.
        return False
    return robust_key
64+
65+
66+
def _get_robust_key_setting(robust_key) -> bool:
67+
"""Handle the robust_key parameter without key-specific logic."""
68+
if robust_key is None:
69+
return False
70+
return robust_key
71+
72+
73+
def _decode_key_from_filesystem(encoded_key: str) -> str:
74+
"""Decode a filesystem-safe key back to the original TensorDict key."""
75+
decoded_parts = []
76+
i = 0
77+
while i < len(encoded_key):
78+
if encoded_key[i] == "%" and i + 2 < len(encoded_key):
79+
try:
80+
hex_str = encoded_key[i + 1 : i + 3]
81+
char_code = int(hex_str, 16)
82+
decoded_parts.append(chr(char_code))
83+
i += 3
84+
except ValueError:
85+
decoded_parts.append(encoded_key[i])
86+
i += 1
87+
else:
88+
decoded_parts.append(encoded_key[i])
89+
i += 1
90+
91+
return "".join(decoded_parts)
92+
93+
94+
# Register JSON backends: declare the "json" backend group with its two
# entries, "json" and "orjson", which the @implement_for variants of
# _json_dumps below and set_backend/get_backend refer to.
# NOTE(review): presumably each value is the module name backing that
# backend -- confirm against the pyvers documentation.
register_backend(group="json", backends={"json": "json", "orjson": "orjson"})
95+
96+
97+
@implement_for("json")
def _json_dumps(data, **kwargs):
    """Serialize ``data`` with the standard-library ``json`` module.

    Args:
        data: The object to serialize.
        **kwargs: Forwarded unchanged to ``json.dumps``.

    Returns:
        Whatever ``json.dumps(data, **kwargs)`` returns.
    """
    # Imported lazily, inside the variant that needs it.
    from json import dumps

    return dumps(data, **kwargs)
103+
104+
105+
@implement_for("orjson")
def _json_dumps(data, **kwargs):  # noqa: F811
    """Serialize ``data`` with the ``orjson`` module.

    Args:
        data: The object to serialize.
        **kwargs: Forwarded to ``orjson.dumps`` after ``separators`` has been
            dropped.

    Returns:
        Whatever ``orjson.dumps`` returns.

    NOTE(review): per the orjson documentation ``dumps`` returns ``bytes``
    (``json.dumps`` returns ``str``) and only takes ``default``/``option``,
    so other ``json.dumps`` keywords would presumably raise -- confirm that
    callers account for both.
    """
    import orjson

    # orjson has no ``separators`` parameter; discard it if it was passed.
    kwargs.pop("separators", None)
    return orjson.dumps(data, **kwargs)
113+
114+
115+
def json_dumps(data, **kwargs):
    """Serialize ``data`` to JSON through the ``_json_dumps`` dispatcher.

    Args:
        data: The object to serialize.
        **kwargs: Forwarded unchanged to ``_json_dumps``.

    Returns:
        The result of ``_json_dumps(data, **kwargs)``.
    """
    serialized = _json_dumps(data, **kwargs)
    return serialized
118+
119+
120+
def set_json_backend(backend):
    """Select the JSON backend.

    Args:
        backend: Either ``"json"`` or ``"orjson"``.

    Raises:
        ValueError: If ``backend`` is neither of the two accepted names.
    """
    if backend in ("json", "orjson"):
        set_backend("json", backend)
        return
    raise ValueError("Backend must be either 'json' or 'orjson'")
125+
126+
127+
def get_json_backend():
    """Return the currently selected backend of the ``"json"`` group.

    Returns:
        The value of ``get_backend("json")``.
    """
    current = get_backend("json")
    return current
130+
131+
132+
# Choose the default backend at import time: orjson when it can be found,
# otherwise the standard-library json module.
set_json_backend(
    "orjson" if importlib.util.find_spec("orjson") is not None else "json"
)
136+
137+
138+
# Rewrite the ``__module__`` of every name listed in ``__all__`` to
# "tensordict.utils", the module that re-imports these helpers.
# NOTE(review): presumably this keeps introspection and by-reference pickling
# resolving them through their original public location -- confirm.
# ``_name`` stays bound in the module namespace after the loop.
for _name in __all__:
    globals()[_name].__module__ = "tensordict.utils"

tensordict/utils.py

Lines changed: 13 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import collections
88
import concurrent.futures
99
import functools
10-
import importlib.util
1110
import itertools
1211
import logging
1312

@@ -39,7 +38,7 @@
3938

4039
import numpy as np
4140
import torch
42-
from pyvers import get_backend, implement_for, register_backend, set_backend
41+
from pyvers import implement_for
4342

4443
from tensordict._C import ( # noqa: F401 # @manual=//pytorch/tensordict:_C
4544
_unravel_key_to_tuple as _unravel_key_to_tuple_cpp,
@@ -2730,145 +2729,12 @@ def unravel_key_list(keys):
27302729
return [unravel_key(key) for key in keys]
27312730

27322731

2733-
def _encode_key_for_filesystem(key: str, *, robust: bool = True) -> str:
2734-
"""Encode a TensorDict key to be safe for filesystem paths.
2735-
2736-
This function provides a bijective mapping from TensorDict keys to
2737-
filesystem-safe filenames by percent-encoding problematic characters.
2738-
2739-
Args:
2740-
key (str): The original TensorDict key
2741-
robust (bool): If True, uses the new robust encoding that percent-encodes
2742-
problematic characters. If False, returns the key unchanged (legacy
2743-
behavior). Defaults to True.
2744-
2745-
Returns:
2746-
str: A filesystem-safe encoded key if robust=True, otherwise the original key
2747-
2748-
Examples:
2749-
>>> _encode_key_for_filesystem("a/b/c")
2750-
"a%2Fb%2Fc"
2751-
>>> _encode_key_for_filesystem("a/b/c", robust=False)
2752-
"a/b/c"
2753-
>>> _encode_key_for_filesystem("normal_key")
2754-
"normal_key"
2755-
"""
2756-
if not robust:
2757-
# Legacy behavior: return key unchanged
2758-
return key
2759-
2760-
# Characters that are problematic across filesystems:
2761-
# - Unix/Linux: / (path separator), null
2762-
# - Windows: < > : " | ? * \ / (path separator), null, control chars
2763-
# - macOS: : (path separator in HFS), null
2764-
# - General: space, percent (for our encoding)
2765-
unsafe_chars = set('/<>:"|?*\\ \0%')
2766-
2767-
# Also encode control characters (0-31) and DEL (127)
2768-
unsafe_chars.update(chr(i) for i in range(32))
2769-
unsafe_chars.add(chr(127))
2770-
2771-
encoded_parts = []
2772-
for char in key:
2773-
if char in unsafe_chars:
2774-
# Percent-encode using uppercase hex for consistency
2775-
encoded_parts.append(f"%{ord(char):02X}")
2776-
else:
2777-
encoded_parts.append(char)
2778-
2779-
return "".join(encoded_parts)
2780-
2781-
2782-
def _get_robust_key_setting_with_warning(key: str, robust_key) -> bool:
2783-
"""Handle the robust_key parameter with smart deprecation warning.
2784-
2785-
Only warns when there's actually a difference between robust and legacy encoding.
2786-
2787-
Args:
2788-
key: The TensorDict key to check
2789-
robust_key: None (auto-detect with warning), False (legacy), True (robust)
2790-
2791-
Returns:
2792-
bool: The effective robust setting to use
2793-
"""
2794-
if robust_key is not None:
2795-
return robust_key
2796-
2797-
# Check if robust encoding would produce a different filename
2798-
robust_encoded = _encode_key_for_filesystem(key, robust=True)
2799-
# Keep this in case we need some further mapping
2800-
legacy_encoded = _encode_key_for_filesystem(key, robust=False)
2801-
2802-
if robust_encoded != legacy_encoded:
2803-
# Only warn when there's actually a difference
2804-
import warnings
2805-
2806-
warnings.warn(
2807-
f"The key '{key}' contains characters that will be handled differently "
2808-
f"in TensorDict v0.12 for better cross-platform support. "
2809-
f"To opt into the new behavior now, use `robust_key=True`. "
2810-
f"To suppress this warning and keep the current behavior, use `robust_key=False`. "
2811-
f"See https://github.com/pytorch/tensordict/issues/1440 for details.",
2812-
FutureWarning,
2813-
stacklevel=3,
2814-
)
2815-
2816-
# Always use legacy behavior when robust_key=None
2817-
return False
2818-
2819-
2820-
def _get_robust_key_setting(robust_key) -> bool:
2821-
"""Handle the robust_key parameter without key-specific logic.
2822-
2823-
Args:
2824-
robust_key: None (fallback to False), False (legacy), True (robust)
2825-
2826-
Returns:
2827-
bool: The effective robust setting to use
2828-
"""
2829-
if robust_key is None:
2830-
return False
2831-
return robust_key
2832-
2833-
2834-
def _decode_key_from_filesystem(encoded_key: str) -> str:
2835-
"""Decode a filesystem-safe key back to the original TensorDict key.
2836-
2837-
This is the reverse of _encode_key_for_filesystem.
2838-
2839-
Args:
2840-
encoded_key (str): A filesystem-safe encoded key
2841-
2842-
Returns:
2843-
str: The original TensorDict key
2844-
2845-
Examples:
2846-
>>> _decode_key_from_filesystem("a%2Fb%2Fc")
2847-
"a/b/c"
2848-
>>> _decode_key_from_filesystem("key%20with%20spaces")
2849-
"key with spaces"
2850-
>>> _decode_key_from_filesystem("normal_key")
2851-
"normal_key"
2852-
"""
2853-
decoded_parts = []
2854-
i = 0
2855-
while i < len(encoded_key):
2856-
if encoded_key[i] == "%" and i + 2 < len(encoded_key):
2857-
try:
2858-
# Decode the hex value
2859-
hex_str = encoded_key[i + 1 : i + 3]
2860-
char_code = int(hex_str, 16)
2861-
decoded_parts.append(chr(char_code))
2862-
i += 3
2863-
except ValueError:
2864-
# Invalid hex sequence, treat as literal %
2865-
decoded_parts.append(encoded_key[i])
2866-
i += 1
2867-
else:
2868-
decoded_parts.append(encoded_key[i])
2869-
i += 1
2870-
2871-
return "".join(decoded_parts)
2732+
from tensordict._utils_key_json import ( # noqa: F401
2733+
_decode_key_from_filesystem,
2734+
_encode_key_for_filesystem,
2735+
_get_robust_key_setting,
2736+
_get_robust_key_setting_with_warning,
2737+
)
28722738

28732739

28742740
def _slice_indices(index: slice, len: int):
@@ -3302,51 +3168,12 @@ def _create_segments_from_list(
33023168
return splits
33033169

33043170

3305-
# Register JSON backends
3306-
register_backend(group="json", backends={"json": "json", "orjson": "orjson"})
3307-
3308-
3309-
@implement_for("json")
3310-
def _json_dumps(data, **kwargs):
3311-
"""JSON serialization using standard json module."""
3312-
import json
3313-
3314-
return json.dumps(data, **kwargs)
3315-
3316-
3317-
@implement_for("orjson")
3318-
def _json_dumps(data, **kwargs): # noqa: F811
3319-
"""JSON serialization using orjson module."""
3320-
import orjson
3321-
3322-
# orjson doesn't support separators parameter, so we need to handle it differently
3323-
if "separators" in kwargs:
3324-
# Remove separators for orjson and use default compact format
3325-
kwargs.pop("separators")
3326-
return orjson.dumps(data, **kwargs)
3327-
3328-
3329-
def json_dumps(data, **kwargs):
3330-
"""Unified JSON serialization function that works with both json and orjson backends."""
3331-
return _json_dumps(data, **kwargs)
3332-
3333-
3334-
def set_json_backend(backend):
3335-
"""Set the JSON backend to use (either 'json' or 'orjson')."""
3336-
if backend not in ["json", "orjson"]:
3337-
raise ValueError("Backend must be either 'json' or 'orjson'")
3338-
set_backend("json", backend)
3339-
3340-
3341-
def get_json_backend():
3342-
"""Get the current JSON backend."""
3343-
return get_backend("json")
3344-
3345-
3346-
if importlib.util.find_spec("orjson") is not None:
3347-
set_json_backend("orjson")
3348-
else:
3349-
set_json_backend("json")
3171+
from tensordict._utils_key_json import ( # noqa: F401
3172+
_json_dumps,
3173+
get_json_backend,
3174+
json_dumps,
3175+
set_json_backend,
3176+
)
33503177

33513178

33523179
class LinkedList(list):

0 commit comments

Comments
 (0)