Describe the bug
TensorDict and tensorclass objects cannot be returned from functions launched with torch.jit.fork. This is extremely unfortunate because tensordicts are such a great way to handle multiple, heterogeneous blocks of data.
To Reproduce
Steps to reproduce the behavior.
import torch
from tensordict import tensorclass

@tensorclass
class MyClass:
    x: torch.Tensor
    y: torch.Tensor

def make_obj(x: torch.Tensor):
    return MyClass(x=x, y=x + 1)

def parallel_class():
    fut1 = torch.jit.fork(make_obj, torch.tensor(1))
    fut2 = torch.jit.fork(make_obj, torch.tensor(2))
    o1 = torch.jit.wait(fut1)
    o2 = torch.jit.wait(fut2)
    return o1, o2

print(parallel_class())  # ❌ Error

import torch
from tensordict import TensorDict

def make_td(x: torch.Tensor):
    # returns a simple TensorDict
    return TensorDict({"x": x, "y": x + 1}, batch_size=[])

def parallel():
    # Attempt to fork two TensorDict-producing tasks
    fut1 = torch.jit.fork(make_td, torch.tensor(1))
    fut2 = torch.jit.fork(make_td, torch.tensor(2))
    td1 = torch.jit.wait(fut1)
    td2 = torch.jit.wait(fut2)
    return td1, td2

print(parallel())  # ❌ Error
Expected behavior
Both of these functions should work as if the returned objects were of type dict[str, Tensor]. Instead, they raise the following error:
Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type TensorDict.
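For comparison, here is a minimal sketch of the dict[str, Tensor] baseline that the expected behavior refers to. It assumes, based on the error message above, that torch.jit.fork accepts a plain dict of tensors as an output. The names make_dict and parallel_dict are hypothetical and only mirror the reproducers.

```python
import torch


def make_dict(x: torch.Tensor) -> dict[str, torch.Tensor]:
    # Same payload as make_td, but returned as a plain dict of tensors
    return {"x": x, "y": x + 1}


def parallel_dict():
    # Fork two dict-producing tasks, then wait on both futures
    fut1 = torch.jit.fork(make_dict, torch.tensor(1))
    fut2 = torch.jit.fork(make_dict, torch.tensor(2))
    return torch.jit.wait(fut1), torch.jit.wait(fut2)


d1, d2 = parallel_dict()
print(d1, d2)
```

If this behaves as assumed, a possible stopgap is to return plain dicts from the forked functions and rebuild the TensorDict in the caller with the same TensorDict(..., batch_size=[]) call used in make_td. This only sidesteps the problem and is not a fix.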
System info
Describe the characteristic of your environment:
- Installed via UV
- Python 3.12
- tensordict==0.10.0
Additional context
Tested on Mac and Linux, with both CPU and CUDA.
Reason and Possible fixes
N/A
Checklist