|
7 | 7 | import collections |
8 | 8 | import concurrent.futures |
9 | 9 | import functools |
10 | | -import importlib.util |
11 | 10 | import itertools |
12 | 11 | import logging |
13 | 12 |
|
|
39 | 38 |
|
40 | 39 | import numpy as np |
41 | 40 | import torch |
42 | | -from pyvers import get_backend, implement_for, register_backend, set_backend |
| 41 | +from pyvers import implement_for |
43 | 42 |
|
44 | 43 | from tensordict._C import ( # noqa: F401 # @manual=//pytorch/tensordict:_C |
45 | 44 | _unravel_key_to_tuple as _unravel_key_to_tuple_cpp, |
@@ -2730,145 +2729,12 @@ def unravel_key_list(keys): |
2730 | 2729 | return [unravel_key(key) for key in keys] |
2731 | 2730 |
|
2732 | 2731 |
|
2733 | | -def _encode_key_for_filesystem(key: str, *, robust: bool = True) -> str: |
2734 | | - """Encode a TensorDict key to be safe for filesystem paths. |
2735 | | -
|
2736 | | - This function provides a bijective mapping from TensorDict keys to |
2737 | | - filesystem-safe filenames by percent-encoding problematic characters. |
2738 | | -
|
2739 | | - Args: |
2740 | | - key (str): The original TensorDict key |
2741 | | - robust (bool): If True, uses the new robust encoding that percent-encodes |
2742 | | - problematic characters. If False, returns the key unchanged (legacy |
2743 | | - behavior). Defaults to True. |
2744 | | -
|
2745 | | - Returns: |
2746 | | - str: A filesystem-safe encoded key if robust=True, otherwise the original key |
2747 | | -
|
2748 | | - Examples: |
2749 | | - >>> _encode_key_for_filesystem("a/b/c") |
2750 | | - "a%2Fb%2Fc" |
2751 | | - >>> _encode_key_for_filesystem("a/b/c", robust=False) |
2752 | | - "a/b/c" |
2753 | | - >>> _encode_key_for_filesystem("normal_key") |
2754 | | - "normal_key" |
2755 | | - """ |
2756 | | - if not robust: |
2757 | | - # Legacy behavior: return key unchanged |
2758 | | - return key |
2759 | | - |
2760 | | - # Characters that are problematic across filesystems: |
2761 | | - # - Unix/Linux: / (path separator), null |
2762 | | - # - Windows: < > : " | ? * \ / (path separator), null, control chars |
2763 | | - # - macOS: : (path separator in HFS), null |
2764 | | - # - General: space, percent (for our encoding) |
2765 | | - unsafe_chars = set('/<>:"|?*\\ \0%') |
2766 | | - |
2767 | | - # Also encode control characters (0-31) and DEL (127) |
2768 | | - unsafe_chars.update(chr(i) for i in range(32)) |
2769 | | - unsafe_chars.add(chr(127)) |
2770 | | - |
2771 | | - encoded_parts = [] |
2772 | | - for char in key: |
2773 | | - if char in unsafe_chars: |
2774 | | - # Percent-encode using uppercase hex for consistency |
2775 | | - encoded_parts.append(f"%{ord(char):02X}") |
2776 | | - else: |
2777 | | - encoded_parts.append(char) |
2778 | | - |
2779 | | - return "".join(encoded_parts) |
2780 | | - |
2781 | | - |
2782 | | -def _get_robust_key_setting_with_warning(key: str, robust_key) -> bool: |
2783 | | - """Handle the robust_key parameter with smart deprecation warning. |
2784 | | -
|
2785 | | - Only warns when there's actually a difference between robust and legacy encoding. |
2786 | | -
|
2787 | | - Args: |
2788 | | - key: The TensorDict key to check |
2789 | | - robust_key: None (auto-detect with warning), False (legacy), True (robust) |
2790 | | -
|
2791 | | - Returns: |
2792 | | - bool: The effective robust setting to use |
2793 | | - """ |
2794 | | - if robust_key is not None: |
2795 | | - return robust_key |
2796 | | - |
2797 | | - # Check if robust encoding would produce a different filename |
2798 | | - robust_encoded = _encode_key_for_filesystem(key, robust=True) |
2799 | | - # Keep this in case we need some further mapping |
2800 | | - legacy_encoded = _encode_key_for_filesystem(key, robust=False) |
2801 | | - |
2802 | | - if robust_encoded != legacy_encoded: |
2803 | | - # Only warn when there's actually a difference |
2804 | | - import warnings |
2805 | | - |
2806 | | - warnings.warn( |
2807 | | - f"The key '{key}' contains characters that will be handled differently " |
2808 | | - f"in TensorDict v0.12 for better cross-platform support. " |
2809 | | - f"To opt into the new behavior now, use `robust_key=True`. " |
2810 | | - f"To suppress this warning and keep the current behavior, use `robust_key=False`. " |
2811 | | - f"See https://github.com/pytorch/tensordict/issues/1440 for details.", |
2812 | | - FutureWarning, |
2813 | | - stacklevel=3, |
2814 | | - ) |
2815 | | - |
2816 | | - # Always use legacy behavior when robust_key=None |
2817 | | - return False |
2818 | | - |
2819 | | - |
2820 | | -def _get_robust_key_setting(robust_key) -> bool: |
2821 | | - """Handle the robust_key parameter without key-specific logic. |
2822 | | -
|
2823 | | - Args: |
2824 | | - robust_key: None (fallback to False), False (legacy), True (robust) |
2825 | | -
|
2826 | | - Returns: |
2827 | | - bool: The effective robust setting to use |
2828 | | - """ |
2829 | | - if robust_key is None: |
2830 | | - return False |
2831 | | - return robust_key |
2832 | | - |
2833 | | - |
2834 | | -def _decode_key_from_filesystem(encoded_key: str) -> str: |
2835 | | - """Decode a filesystem-safe key back to the original TensorDict key. |
2836 | | -
|
2837 | | - This is the reverse of _encode_key_for_filesystem. |
2838 | | -
|
2839 | | - Args: |
2840 | | - encoded_key (str): A filesystem-safe encoded key |
2841 | | -
|
2842 | | - Returns: |
2843 | | - str: The original TensorDict key |
2844 | | -
|
2845 | | - Examples: |
2846 | | - >>> _decode_key_from_filesystem("a%2Fb%2Fc") |
2847 | | - "a/b/c" |
2848 | | - >>> _decode_key_from_filesystem("key%20with%20spaces") |
2849 | | - "key with spaces" |
2850 | | - >>> _decode_key_from_filesystem("normal_key") |
2851 | | - "normal_key" |
2852 | | - """ |
2853 | | - decoded_parts = [] |
2854 | | - i = 0 |
2855 | | - while i < len(encoded_key): |
2856 | | - if encoded_key[i] == "%" and i + 2 < len(encoded_key): |
2857 | | - try: |
2858 | | - # Decode the hex value |
2859 | | - hex_str = encoded_key[i + 1 : i + 3] |
2860 | | - char_code = int(hex_str, 16) |
2861 | | - decoded_parts.append(chr(char_code)) |
2862 | | - i += 3 |
2863 | | - except ValueError: |
2864 | | - # Invalid hex sequence, treat as literal % |
2865 | | - decoded_parts.append(encoded_key[i]) |
2866 | | - i += 1 |
2867 | | - else: |
2868 | | - decoded_parts.append(encoded_key[i]) |
2869 | | - i += 1 |
2870 | | - |
2871 | | - return "".join(decoded_parts) |
| 2732 | +from tensordict._utils_key_json import ( # noqa: F401 |
| 2733 | + _decode_key_from_filesystem, |
| 2734 | + _encode_key_for_filesystem, |
| 2735 | + _get_robust_key_setting, |
| 2736 | + _get_robust_key_setting_with_warning, |
| 2737 | +) |
2872 | 2738 |
|
2873 | 2739 |
|
2874 | 2740 | def _slice_indices(index: slice, len: int): |
@@ -3302,51 +3168,12 @@ def _create_segments_from_list( |
3302 | 3168 | return splits |
3303 | 3169 |
|
3304 | 3170 |
|
3305 | | -# Register JSON backends |
3306 | | -register_backend(group="json", backends={"json": "json", "orjson": "orjson"}) |
3307 | | - |
3308 | | - |
3309 | | -@implement_for("json") |
3310 | | -def _json_dumps(data, **kwargs): |
3311 | | - """JSON serialization using standard json module.""" |
3312 | | - import json |
3313 | | - |
3314 | | - return json.dumps(data, **kwargs) |
3315 | | - |
3316 | | - |
3317 | | -@implement_for("orjson") |
3318 | | -def _json_dumps(data, **kwargs): # noqa: F811 |
3319 | | - """JSON serialization using orjson module.""" |
3320 | | - import orjson |
3321 | | - |
3322 | | - # orjson doesn't support separators parameter, so we need to handle it differently |
3323 | | - if "separators" in kwargs: |
3324 | | - # Remove separators for orjson and use default compact format |
3325 | | - kwargs.pop("separators") |
3326 | | - return orjson.dumps(data, **kwargs) |
3327 | | - |
3328 | | - |
3329 | | -def json_dumps(data, **kwargs): |
3330 | | - """Unified JSON serialization function that works with both json and orjson backends.""" |
3331 | | - return _json_dumps(data, **kwargs) |
3332 | | - |
3333 | | - |
3334 | | -def set_json_backend(backend): |
3335 | | - """Set the JSON backend to use (either 'json' or 'orjson').""" |
3336 | | - if backend not in ["json", "orjson"]: |
3337 | | - raise ValueError("Backend must be either 'json' or 'orjson'") |
3338 | | - set_backend("json", backend) |
3339 | | - |
3340 | | - |
3341 | | -def get_json_backend(): |
3342 | | - """Get the current JSON backend.""" |
3343 | | - return get_backend("json") |
3344 | | - |
3345 | | - |
3346 | | -if importlib.util.find_spec("orjson") is not None: |
3347 | | - set_json_backend("orjson") |
3348 | | -else: |
3349 | | - set_json_backend("json") |
| 3171 | +from tensordict._utils_key_json import ( # noqa: F401 |
| 3172 | + _json_dumps, |
| 3173 | + get_json_backend, |
| 3174 | + json_dumps, |
| 3175 | + set_json_backend, |
| 3176 | +) |
3350 | 3177 |
|
3351 | 3178 |
|
3352 | 3179 | class LinkedList(list): |
|
0 commit comments