@@ -878,6 +878,65 @@ def test_export_seq(self):
878878 out = torch .export .export (tdm , args = (), kwargs = {"x" : x , "y" : y })
879879 torch .testing .assert_close (out .module ()(x = x , y = y ), tdm (x = x , y = y ))
880880
881+ # This tests passes but there are various things that need to be fixed:
882+ # - we cannot use vmap directly
883+ # - if we use strict=True, there's an error due to the fact that export ignores
884+ # the replacement of the params (ie, params are still on "meta" and the values
885+ # after the call on the exported module don't match the original ones).
886+ # Currently only works with strict=False, because export fails to see that
887+ # the params in the module have changed and are not 'meta' anymore => this
888+ # is symptomatic of export failing to see the functional call
889+ @pytest .mark .parametrize ("strict" , [False ]) # , True])
890+ def test_export_with_td_params (self , strict ):
891+ module = torch .nn .Sequential (
892+ torch .nn .Linear (3 , 4 ),
893+ torch .nn .Linear (4 , 5 ),
894+ )
895+ module_td = TensorDictParams (
896+ TensorDict .from_module (module ).data .expand (2 ).clone ()
897+ )
898+ assert all (
899+ isinstance (p , torch .nn .Parameter ) for p in module_td .values (True , True )
900+ )
901+
902+ class MyModule (torch .nn .Module ):
903+ def __init__ (self , td_params ):
904+ super ().__init__ ()
905+ self .tdparams = td_params
906+ self .arch = torch .nn .Sequential (
907+ torch .nn .Linear (3 , 4 , device = "meta" ),
908+ torch .nn .Linear (4 , 5 , device = "meta" ),
909+ )
910+
911+ def forward (self , x ):
912+ # vmap with params currently fails
913+ # return torch.vmap(self.batch_forward, (0, None))(self.tdparams, x)
914+ return torch .stack (
915+ [self .batch_forward (p , x ) for p in self .tdparams .unbind (0 )]
916+ )
917+
918+ def batch_forward (self , params , x ):
919+ with params .to_module (self .arch ):
920+ return self .arch (x )
921+ # This could be an option but dynamo doesn't know how to trace through state_dict ops
922+ # sd = self.arch.state_dict()
923+ # try:
924+ # self.arch.load_state_dict(params.flatten_keys().to_dict(), assign=True)
925+ # return self.arch(x)
926+ # finally:
927+ # self.arch.load_state_dict(sd, assign=True)
928+
929+ m = MyModule (module_td )
930+ x = torch .randn (3 )
931+ assert m (x ).shape == (2 , 5 )
932+ exported_module = torch .export .export (
933+ m ,
934+ args = (),
935+ kwargs = {"x" : x },
936+ strict = strict ,
937+ )
938+ torch .testing .assert_close (exported_module .module ()(x = x ), m (x ))
939+
881940
882941@pytest .mark .skipif (not _has_onnx , reason = "ONNX is not available" )
883942class TestONNXExport :
0 commit comments