
Commit 34b8f14

Author: Vincent Moens

[BugFix] Fix TDParams compatibility with export

ghstack-source-id: 329e30d
Pull Request resolved: #1285
Parent: 6d8119c

2 files changed: 63 additions & 0 deletions

tensordict/nn/params.py

Lines changed: 4 additions & 0 deletions
@@ -798,6 +798,10 @@ def __getattr__(self, item: str) -> Any:
             try:
                 return getattr(self.__dict__["_param_td"], item)
             except AttributeError:
+                # During some state-dict loads, we may encounter cases where pytorch does a getattr
+                # with the module name
+                if item in self.keys():
+                    return TensorDictParams(self[item])
                 return super().__getattr__(item)
         else:
             return super().__getattr__(item)
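
For context, a minimal sketch of the behavior this fallback enables (illustrative only: the nested "encoder" key and the tensor shapes are made up, and this assumes tensordict and torch are importable). During some state-dict loads, PyTorch issues a getattr with a submodule-style name; the fallback now resolves such names against the wrapped TensorDict's own keys.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

# Hypothetical nested params; "encoder" is a key of the wrapped TensorDict,
# not a python attribute of the module.
params = TensorDictParams(
    TensorDict({"encoder": {"weight": torch.randn(4, 3)}}, batch_size=[])
)
# Before the fix this getattr raised AttributeError; with the fallback, the
# nested entry is wrapped in a TensorDictParams and returned.
sub = params.encoder
assert isinstance(sub, TensorDictParams)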

test/test_compile.py

Lines changed: 59 additions & 0 deletions
@@ -878,6 +878,65 @@ def test_export_seq(self):
         out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
+    # This test passes, but several things still need to be fixed:
+    # - we cannot use vmap directly;
+    # - with strict=True there's an error because export ignores the
+    #   replacement of the params (i.e., params are still on "meta" and the
+    #   values after the call on the exported module don't match the originals).
+    #   Currently this only works with strict=False, because export fails to
+    #   see that the params in the module have changed and are no longer on
+    #   "meta" => this is symptomatic of export failing to see the functional call.
+    @pytest.mark.parametrize("strict", [False])  # , True])
+    def test_export_with_td_params(self, strict):
+        module = torch.nn.Sequential(
+            torch.nn.Linear(3, 4),
+            torch.nn.Linear(4, 5),
+        )
+        module_td = TensorDictParams(
+            TensorDict.from_module(module).data.expand(2).clone()
+        )
+        assert all(
+            isinstance(p, torch.nn.Parameter) for p in module_td.values(True, True)
+        )
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, td_params):
+                super().__init__()
+                self.tdparams = td_params
+                self.arch = torch.nn.Sequential(
+                    torch.nn.Linear(3, 4, device="meta"),
+                    torch.nn.Linear(4, 5, device="meta"),
+                )
+
+            def forward(self, x):
+                # vmap with params currently fails
+                # return torch.vmap(self.batch_forward, (0, None))(self.tdparams, x)
+                return torch.stack(
+                    [self.batch_forward(p, x) for p in self.tdparams.unbind(0)]
+                )
+
+            def batch_forward(self, params, x):
+                with params.to_module(self.arch):
+                    return self.arch(x)
+                # This could be an option, but dynamo doesn't know how to trace through state_dict ops:
+                # sd = self.arch.state_dict()
+                # try:
+                #     self.arch.load_state_dict(params.flatten_keys().to_dict(), assign=True)
+                #     return self.arch(x)
+                # finally:
+                #     self.arch.load_state_dict(sd, assign=True)
+
+        m = MyModule(module_td)
+        x = torch.randn(3)
+        assert m(x).shape == (2, 5)
+        exported_module = torch.export.export(
+            m,
+            args=(),
+            kwargs={"x": x},
+            strict=strict,
+        )
+        torch.testing.assert_close(exported_module.module()(x=x), m(x))
+
 
 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
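
For reference, the functional-call pattern the new test leans on can be sketched in isolation (a minimal sketch, assuming a recent tensordict/torch; the Linear shape and the zeroed parameter set are arbitrary choices for illustration). TensorDict.from_module extracts a module's parameters as a TensorDict, and to_module swaps a parameter set into the module for the duration of a with block, restoring the originals on exit.

import torch
from tensordict import TensorDict

net = torch.nn.Linear(3, 4)
# Snapshot the module's parameters as a TensorDict.
params = TensorDict.from_module(net)
# Build an alternative parameter set (zeros here, so the output is all zeros).
alt_params = params.apply(lambda t: torch.zeros_like(t))
with alt_params.to_module(net):
    y = net(torch.randn(3))  # runs with alt_params swapped in
# Outside the block, the original parameters are back in place.
assert (y == 0).all()

The test's batch_forward applies this pattern to each slice of the batched TensorDictParams and stacks the results, emulating what a vmap over parameters would do.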
