Skip to content

Commit 6e6cc00

Browse files
kevmo314 and vmoens authored
[Feature] Add free threading support (#1481)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent f6f1225 commit 6e6cc00

6 files changed

Lines changed: 64 additions & 28 deletions

File tree

.github/unittest/linux/scripts/setup_env.sh

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,28 @@ eval "$(${conda_dir}/bin/conda shell.bash hook)"
4040
printf "python: ${PYTHON_VERSION}\n"
4141
if [ ! -d "${env_dir}" ]; then
4242
printf "* Creating a test environment\n"
43-
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
43+
if [ "${PYTHON_VERSION}" == "3.14t" ]; then
44+
# Install free-threaded Python 3.14 from conda-forge
45+
conda create --prefix "${env_dir}" -y -c conda-forge python-freethreading
46+
# Set PYTHON_GIL=0 to keep GIL disabled
47+
export PYTHON_GIL=0
48+
else
49+
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
50+
fi
4451
fi
4552
conda activate "${env_dir}"
4653

54+
# For free-threaded Python, ensure PYTHON_GIL=0 is set
55+
if [ "${PYTHON_VERSION}" == "3.14t" ]; then
56+
export PYTHON_GIL=0
57+
fi
58+
4759
# 3. Install Conda dependencies
4860
printf "* Installing dependencies (except PyTorch)\n"
49-
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
61+
# Don't add python version constraint for free-threaded builds
62+
if [ "${PYTHON_VERSION}" != "3.14t" ]; then
63+
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
64+
fi
5065
cat "${this_dir}/environment.yml"
5166

5267
pip install pip --upgrade

.github/workflows/test-linux.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ jobs:
5858
test-cpu:
5959
strategy:
6060
matrix:
61-
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
61+
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"]
6262
fail-fast: false
6363
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
6464
permissions:

tensordict/_td.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2834,8 +2834,6 @@ def _memmap_(
28342834
if inplace:
28352835
self._is_memmap = True
28362836
self._is_shared = False # since they are mutually exclusive
2837-
if self._validate_value_cached is not None:
2838-
delattr(self, "_validate_value_cached")
28392837
self._device = torch.device("cpu")
28402838
else:
28412839
dest._is_memmap = True

tensordict/base.py

Lines changed: 10 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -661,7 +661,6 @@ def __getstate__(self) -> dict[str, Any]:
661661
"_last_op",
662662
"_cache",
663663
"__lock_parents_weakrefs",
664-
"_validate_value_cached",
665664
):
666665
result.pop(key, None)
667666
return result
@@ -3503,8 +3502,6 @@ def dtype(self):
35033502
return self._dtype()
35043503

35053504
def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
3506-
if self._validate_value_cached is not None:
3507-
delattr(self, "_validate_value_cached")
35083505
if new_batch_size == self.batch_size:
35093506
return
35103507
if self._lazy:
@@ -5768,17 +5765,13 @@ def clear_device_(self) -> Self:
57685765

57695766
"""
57705767
self._device = None
5771-
if self._validate_value_cached is not None:
5772-
delattr(self, "_validate_value_cached")
57735768
for value in self.values():
57745769
if _is_tensor_collection(type(value)):
57755770
value.clear_device_()
57765771
return self
57775772

57785773
def _set_device(self, device: torch.device) -> Self:
57795774
self._device = device
5780-
if self._validate_value_cached is not None:
5781-
delattr(self, "_validate_value_cached")
57825775
for value in self.values():
57835776
if _is_tensor_collection(type(value)):
57845777
value._set_device(device=device)
@@ -12964,26 +12957,21 @@ def _validate_key(self, key: NestedKey) -> NestedKey:
1296412957
raise KeyError(_GENERIC_NESTED_ERR.format(key))
1296512958
return key
1296612959

12967-
_validate_value_cached: str | None = None
12968-
1296912960
@property
1297012961
def _validate_value(self):
1297112962
if is_compiling():
1297212963
return self._validate_value_generic
12973-
_validate_value_cached = self._validate_value_cached
12974-
if _validate_value_cached is None:
12975-
if self.device:
12976-
if self.batch_size:
12977-
_validate_value_cached = "_validate_value_generic"
12978-
else:
12979-
_validate_value_cached = "_validate_value_batchfree"
12964+
if self.device:
12965+
if self.batch_size:
12966+
method_name = "_validate_value_generic"
1298012967
else:
12981-
if self.batch_size:
12982-
_validate_value_cached = "_validate_value_devicefree"
12983-
else:
12984-
_validate_value_cached = "_validate_value_batchfree_devicefree"
12985-
self._validate_value_cached = _validate_value_cached
12986-
return getattr(self, _validate_value_cached)
12968+
method_name = "_validate_value_batchfree"
12969+
else:
12970+
if self.batch_size:
12971+
method_name = "_validate_value_devicefree"
12972+
else:
12973+
method_name = "_validate_value_batchfree_devicefree"
12974+
return getattr(self, method_name)
1298712975

1298812976
def _validate_value_generic(
1298912977
self,

tensordict/csrc/pybind.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
namespace py = pybind11;
1515

16-
PYBIND11_MODULE(_C, m) {
16+
PYBIND11_MODULE(_C, m, py::mod_gil_not_used()) {
1717
m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat
1818
m.def("unravel_key", &unravel_key, py::arg("key"));
1919
m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));

test/test_tensordict.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14127,6 +14127,41 @@ def test_memmap_robust_key_encoding_bijective(self):
1412714127
assert legacy == key
1412814128

1412914129

14130+
class TestFreeThreading:
14131+
"""Tests for free-threading (GIL-less Python) compatibility."""
14132+
14133+
def test_concurrent_gc_stress(self):
14134+
"""Regression test for free-threading race condition (PR #1481).
14135+
14136+
This test exercises concurrent access to TensorDict instances while
14137+
triggering garbage collection. On Python 3.14t with PYTHON_GIL=0,
14138+
this would previously cause segfaults due to a race condition in
14139+
the _validate_value_cached attribute.
14140+
14141+
The test passes on all Python versions but only catches the actual
14142+
race condition on free-threading builds.
14143+
"""
14144+
import threading
14145+
14146+
def gc_stress():
14147+
for _ in range(100):
14148+
td = TensorDict(
14149+
{"a": torch.randn(5, 5), "b": {"c": torch.randn(5, 3)}},
14150+
batch_size=[5],
14151+
)
14152+
# Access _validate_value to trigger the code path that had the race
14153+
_ = td._validate_value
14154+
td = None
14155+
gc.collect()
14156+
14157+
threads = [threading.Thread(target=gc_stress) for _ in range(8)]
14158+
for t in threads:
14159+
t.start()
14160+
for t in threads:
14161+
t.join()
14162+
# If we get here without segfault, the test passes
14163+
14164+
1413014165
if __name__ == "__main__":
1413114166
args, unknown = argparse.ArgumentParser().parse_known_args()
1413214167
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)