Skip to content

Commit 3352c56

Browse files
FindHao and meta-codesync[bot]
authored and committed
Add stdlib json fallback for free-threading compatibility (#365)
Summary: Pull Request resolved: #365

D97564498 introduced orjson as a hard dependency, which broke 40+ tests on CPython 3.14 free-threading builds because orjson does not yet support free-threading (PEP 703).

This diff adds a `_json_compat` wrapper module that:
- Uses orjson when available (for performance on standard CPython)
- Falls back to stdlib json when orjson is unavailable (e.g., free-threading builds)
- Provides a unified API: `loads()`, `dumps()`, and `JSONDecodeError`
- Returns `str` from `dumps()` (no `.decode()` needed at call sites)

All 18 files that previously imported orjson directly now import from `tritonparse._json_compat` instead. Two files (`sourcemap_utils.py` and `reproducer/utils.py`) intentionally keep stdlib json and are unchanged.

Also moves orjson from a required dependency to an optional extra in `pyproject.toml` for OSS compatibility.

Reviewed By: xuzhao9

Differential Revision: D98194947

fbshipit-source-id: bd9ff863a3156514b346a417403da18d390a5e22
1 parent 10c2aa0 commit 3352c56

18 files changed

Lines changed: 177 additions & 149 deletions

tritonparse/_json_compat.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
"""
4+
JSON compatibility layer for tritonparse.
5+
6+
Provides a unified interface that uses orjson when available (for performance)
7+
and falls back to stdlib json (for environments where orjson is unavailable,
8+
e.g. CPython 3.14 free-threading builds).
9+
10+
``loads()`` accepts ``str | bytes | bytearray | memoryview`` inputs.
11+
``dumps()`` returns ``str``.
12+
"""
13+
14+
try:
15+
import orjson as _orjson
16+
17+
_HAS_ORJSON = True
18+
except ImportError:
19+
import json as _json
20+
21+
_HAS_ORJSON = False
22+
23+
24+
def _coerce_keys(obj):
25+
"""Recursively convert non-string dict keys to strings.
26+
27+
stdlib ``json.dumps`` raises ``TypeError`` on non-string keys, whereas
28+
orjson's ``OPT_NON_STR_KEYS`` converts them automatically. This helper
29+
replicates that behavior for the fallback path.
30+
"""
31+
if isinstance(obj, dict):
32+
return {str(k): _coerce_keys(v) for k, v in obj.items()}
33+
if isinstance(obj, list):
34+
return [_coerce_keys(v) for v in obj]
35+
return obj
36+
37+
38+
if _HAS_ORJSON:
39+
JSONDecodeError = _orjson.JSONDecodeError
40+
41+
def loads(data):
    """Parse a JSON document with orjson and return the Python object.

    Args:
        data: The JSON input, passed to ``orjson.loads`` unchanged.

    Returns:
        Whatever ``orjson.loads`` produces for ``data``.
    """
    parsed = _orjson.loads(data)
    return parsed
44+
45+
def dumps(obj, *, indent=False, sort_keys=False):
    """Encode ``obj`` as JSON with orjson and return it as ``str``.

    Args:
        obj: The object to serialize.
        indent: When truthy, pretty-print using orjson's 2-space indent.
        sort_keys: When truthy, emit dictionary keys in sorted order.

    Returns:
        The JSON text, decoded from the ``bytes`` that orjson produces.
    """
    # Non-string dict keys are always allowed; the other flags are opt-in.
    flags = _orjson.OPT_NON_STR_KEYS
    flags |= _orjson.OPT_INDENT_2 if indent else 0
    flags |= _orjson.OPT_SORT_KEYS if sort_keys else 0
    encoded = _orjson.dumps(obj, option=flags)
    return encoded.decode()
59+
60+
else:
61+
from json import JSONDecodeError # noqa: F401
62+
63+
def loads(data):
    """Deserialize a JSON document to a Python object using stdlib ``json``.

    Args:
        data: JSON text as ``str``, or UTF-8 encoded ``bytes``,
            ``bytearray`` or ``memoryview``.

    Returns:
        The decoded Python object.

    Raises:
        JSONDecodeError: If ``data`` is not valid JSON, or is a bytes-like
            object that is not valid UTF-8.
    """
    if isinstance(data, (bytes, bytearray, memoryview)):
        # memoryview has no ``decode`` method, so materialize it first.
        raw = bytes(data) if isinstance(data, memoryview) else data
        try:
            data = raw.decode("utf-8")
        except UnicodeDecodeError as exc:
            # Report undecodable input as JSONDecodeError so that callers
            # catching only JSONDecodeError also handle malformed bytes
            # (NOTE(review): orjson is understood to report invalid UTF-8
            # the same way -- confirm against the orjson docs).
            raise JSONDecodeError(
                f"input is not valid UTF-8: {exc.reason}", "", 0
            ) from exc
    return _json.loads(data)
70+
71+
def dumps(obj, *, indent=False, sort_keys=False):
    """Encode ``obj`` as a JSON ``str`` using stdlib ``json``.

    Dict keys are first passed through ``_coerce_keys`` so that non-string
    keys do not make ``json.dumps`` raise.  Non-ASCII characters are written
    as-is (``ensure_ascii=False``).

    Args:
        obj: The object to serialize.
        indent: When truthy, pretty-print with a 2-space indent; otherwise
            use the compact ``","`` / ``":"`` separators.
        sort_keys: When truthy, emit dictionary keys in sorted order.

    Returns:
        The JSON text.
    """
    # Pretty and compact output are mutually exclusive layouts.
    layout = {"indent": 2} if indent else {"separators": (",", ":")}
    ordering = {"sort_keys": True} if sort_keys else {}
    normalized = _coerce_keys(obj)
    return _json.dumps(normalized, ensure_ascii=False, **layout, **ordering)

tritonparse/ai/client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass, field
1717
from typing import Any, Iterator, List, Optional, Tuple
1818

19-
import orjson
19+
from tritonparse._json_compat import JSONDecodeError, loads
2020

2121

2222
@dataclass
@@ -292,9 +292,9 @@ def chat(
292292
stdout_snippet = ""
293293
if result and result.stdout:
294294
try:
295-
data = orjson.loads(result.stdout)
295+
data = loads(result.stdout)
296296
stdout_snippet = data.get("result", result.stdout[:500])
297-
except (orjson.JSONDecodeError, AttributeError):
297+
except (JSONDecodeError, AttributeError):
298298
stdout_snippet = result.stdout[:500]
299299
stderr_snippet = (
300300
result.stderr[:500] if result and result.stderr else "(empty)"
@@ -384,7 +384,7 @@ def chat_stream(
384384
continue
385385

386386
try:
387-
event = orjson.loads(line)
387+
event = loads(line)
388388
event_type = event.get("type")
389389

390390
if event_type == "assistant":
@@ -398,7 +398,7 @@ def chat_stream(
398398
if "session_id" in event:
399399
self.session_id = event["session_id"]
400400

401-
except orjson.JSONDecodeError:
401+
except JSONDecodeError:
402402
continue
403403

404404
process.wait(timeout=self.timeout)
@@ -441,14 +441,14 @@ def _parse_response(self, stdout: str) -> Response:
441441
Parsed Response object
442442
"""
443443
try:
444-
data = orjson.loads(stdout)
444+
data = loads(stdout)
445445
self.session_id = data.get("session_id")
446446
return Response(
447447
content=data.get("result", stdout),
448448
session_id=self.session_id,
449449
cost_usd=data.get("total_cost_usd"),
450450
raw=data,
451451
)
452-
except orjson.JSONDecodeError:
452+
except JSONDecodeError:
453453
# Non-JSON output, return as-is
454454
return Response(content=stdout.strip())

tritonparse/ai/parsers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313
from typing import Optional
1414

15-
import orjson
15+
from tritonparse._json_compat import JSONDecodeError, loads
1616

1717

1818
def extract_json(text: str) -> Optional[dict]:
@@ -34,17 +34,17 @@ def extract_json(text: str) -> Optional[dict]:
3434

3535
# Try direct JSON parsing
3636
try:
37-
return orjson.loads(text.strip())
38-
except orjson.JSONDecodeError:
37+
return loads(text.strip())
38+
except JSONDecodeError:
3939
pass
4040

4141
# Try extracting from markdown code block
4242
pattern = r"```(?:json)?\s*\n?(.*?)\n?```"
4343
match = re.search(pattern, text, re.DOTALL)
4444
if match:
4545
try:
46-
return orjson.loads(match.group(1).strip())
47-
except orjson.JSONDecodeError:
46+
return loads(match.group(1).strip())
47+
except JSONDecodeError:
4848
pass
4949

5050
return None

tritonparse/bisect/state.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pathlib import Path
1515
from typing import Any, Dict, Optional
1616

17-
import orjson
17+
from tritonparse._json_compat import dumps, loads
1818

1919

2020
class BisectPhase(Enum):
@@ -322,12 +322,7 @@ def save(
322322

323323
# Write JSON
324324
with open(save_path, "w") as f:
325-
f.write(
326-
orjson.dumps(
327-
state.to_dict(),
328-
option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS,
329-
).decode()
330-
)
325+
f.write(dumps(state.to_dict(), indent=True))
331326

332327
return save_path
333328

@@ -344,11 +339,11 @@ def load(path: str) -> BisectState:
344339
345340
Raises:
346341
FileNotFoundError: If state file doesn't exist.
347-
orjson.JSONDecodeError: If file is not valid JSON.
342+
JSONDecodeError: If file is not valid JSON.
348343
ValueError: If state data is invalid.
349344
"""
350345
with open(path, "r") as f:
351-
data = orjson.loads(f.read())
346+
data = loads(f.read())
352347
return BisectState.from_dict(data)
353348

354349
@staticmethod

tritonparse/diff/core/event_matcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections import defaultdict
1111
from typing import Any
1212

13-
import orjson
13+
from tritonparse._json_compat import dumps, loads
1414
from tritonparse.tools.prettify_ndjson import load_ndjson
1515

1616

@@ -89,9 +89,9 @@ def ensure_source_mappings(event: dict[str, Any]) -> dict[str, Any]:
8989
# Reuse parse module function to generate source_mappings
9090
from tritonparse.parse.trace_processor import parse_single_trace_content
9191

92-
event_str = orjson.dumps(event, option=orjson.OPT_NON_STR_KEYS).decode()
92+
event_str = dumps(event)
9393
parsed_str = parse_single_trace_content(event_str)
94-
return orjson.loads(parsed_str.strip())
94+
return loads(parsed_str.strip())
9595

9696

9797
def match_events_by_index(

tritonparse/diff/output/event_writer.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dataclasses import asdict
1515
from typing import Any
1616

17-
import orjson
17+
from tritonparse._json_compat import dumps
1818
from tritonparse.diff.core.diff_types import (
1919
CompilationDiffResult,
2020
PythonLineDiff,
@@ -173,12 +173,7 @@ def append_diff_to_file(file_path: str, diff_event: dict[str, Any]) -> None:
173173
diff_event: The compilation_diff event to append.
174174
"""
175175
with open(file_path, "a") as f:
176-
f.write(
177-
orjson.dumps(
178-
_sanitize_non_finite_floats(diff_event), option=orjson.OPT_NON_STR_KEYS
179-
).decode()
180-
+ "\n"
181-
)
176+
f.write(dumps(_sanitize_non_finite_floats(diff_event)) + "\n")
182177

183178

184179
class ConsolidatedDiffWriter:
@@ -297,22 +292,10 @@ def write(self, output_path: str | None = None) -> None:
297292
with open(path, "w") as f:
298293
# Write all unique compilation events first
299294
for event in self._events.values():
300-
f.write(
301-
orjson.dumps(
302-
_sanitize_non_finite_floats(event),
303-
option=orjson.OPT_NON_STR_KEYS,
304-
).decode()
305-
+ "\n"
306-
)
295+
f.write(dumps(_sanitize_non_finite_floats(event)) + "\n")
307296
# Then write all diff events
308297
for diff in self._diffs:
309-
f.write(
310-
orjson.dumps(
311-
_sanitize_non_finite_floats(diff),
312-
option=orjson.OPT_NON_STR_KEYS,
313-
).decode()
314-
+ "\n"
315-
)
298+
f.write(dumps(_sanitize_non_finite_floats(diff)) + "\n")
316299

317300
@property
318301
def event_count(self) -> int:

tritonparse/info/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,11 @@ def _parse_value(value: str) -> Any:
169169

170170
# Check for list (e.g., "[3024, 10752]")
171171
if value.startswith("[") and value.endswith("]"):
172-
import orjson
172+
from tritonparse._json_compat import JSONDecodeError, loads
173173

174174
try:
175-
return orjson.loads(value)
176-
except orjson.JSONDecodeError:
175+
return loads(value)
176+
except JSONDecodeError:
177177
return value
178178

179179
# Try to convert to int

tritonparse/parse/common.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pathlib import Path
1010
from typing import Any, Dict, List, Optional, Tuple
1111

12-
import orjson
12+
from tritonparse._json_compat import dumps
1313
from tritonparse.shared_vars import (
1414
DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER as LOG_PREFIX,
1515
is_fbcode,
@@ -518,11 +518,7 @@ def parse_logs(
518518
# Save file mapping to parsed_log_dir
519519
log_file_list_path = os.path.join(parsed_log_dir, "log_file_list.json")
520520
with open(log_file_list_path, "w") as f:
521-
f.write(
522-
orjson.dumps(
523-
file_mapping, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
524-
).decode()
525-
)
521+
f.write(dumps(file_mapping, indent=True))
526522

527523
# NOTICE: this print is required for tlparser-tritonparse integration
528524
# DON'T REMOVE THIS PRINT

0 commit comments

Comments
 (0)