Skip to content

Commit fe2aede

Browse files
Update on "[Bugfix] Allow non Module in TensorDictModule when method is passed"
relax this check as [LLM](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py#L53) in vLLM does not subclass nn.Module [ghstack-poisoned]
1 parent 62b5dc9 commit fe2aede

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

tensordict/nn/common.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,14 +1072,12 @@ def __init__(
10721072
if method is None:
10731073
if type(module) is type or not callable(module):
10741074
raise ValueError(
1075-
f"Module {module} if type {type(module)} is not callable. "
1075+
f"Module {module} of type {type(module)} is not callable. "
10761076
f"Typical accepted types are nn.Module or TensorDictModule."
10771077
)
10781078
else:
10791079
if not (hasattr(module, method) and callable(getattr(module, method))):
1080-
raise ValueError(
1081-
f"Module {module} does not have a callable method {method}. "
1082-
)
1080+
raise ValueError(f"Module {module} does not have a method {method}. ")
10831081
self.out_keys = out_keys
10841082
self.in_keys = in_keys
10851083

test/test_nn.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,21 @@ def my_func(self, tensor: torch.Tensor, *, an_integer: int):
295295
td = s(TensorDict(a=0))
296296

297297
assert td["b"] == 4
298-
298+
299299
def test_non_module_method(self):
300-
class MyNet:
300+
class NonModuleNet:
301+
def __init__(self):
302+
self.foo = 2
303+
301304
def my_func(self, tensor: torch.Tensor, *, an_integer: int):
302305
return tensor + an_integer
303-
306+
304307
s = TensorDictSequential(
305308
{
306309
"a": lambda td: td + 1,
307310
"b": lambda td: td * 2,
308311
"c": TensorDictModule(
309-
MyNet(),
312+
NonModuleNet(),
310313
in_keys=["a"],
311314
out_keys=["b"],
312315
method="my_func",
@@ -319,6 +322,15 @@ def my_func(self, tensor: torch.Tensor, *, an_integer: int):
319322

320323
assert td["b"] == 4
321324

325+
with pytest.raises(ValueError, match="does not have a method foo."):
326+
TensorDictModule(
327+
NonModuleNet(),
328+
in_keys=["a"],
329+
out_keys=["b"],
330+
method="foo",
331+
method_kwargs={"an_integer": 2},
332+
),
333+
322334
def test_mutable_sequence(self):
323335
in_keys = self.MyMutableSequence(["a", "b", "c"])
324336
out_keys = self.MyMutableSequence(["d", "e", "f"])

0 commit comments

Comments
 (0)