@@ -36,7 +36,7 @@ def test_transfusion(
         dim_latent = (384, 192), # specify multiple latent dimensions
         modality_default_shape = ((32,), (64,)),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
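For orientation, here is a minimal sketch of the surrounding constructor call that these hunks keep shrinking. The keyword arguments are taken from the hunk above; `num_text_tokens`, the import path, and the literal `use_flex_attn` value are assumptions about the test file:

    # sketch of the surrounding test setup; num_text_tokens and the import
    # are assumed, the remaining kwargs appear verbatim in the hunk above
    from transfusion_pytorch import Transfusion

    model = Transfusion(
        num_text_tokens = 256,                    # assumed test vocabulary size
        dim_latent = (384, 192),                  # one latent dim per modality
        modality_default_shape = ((32,), (64,)),  # default token counts per modality
        transformer = dict(
            dim = 64,             # small width so the test suite stays fast
            depth = 2,
            use_flex_attn = False # parametrized in the real test
        )
    )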
@@ -80,7 +80,7 @@ def test_auto_modality_transform(
         channel_first_latent = True,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
@@ -117,7 +117,7 @@ def test_text(
         channel_first_latent = True,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
@@ -141,7 +141,7 @@ def test_modality_only(
         channel_first_latent = channel_first,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = False
         )
@@ -173,8 +173,8 @@ def test_text_image_end_to_end(
         modality_encoder = mock_vae_encoder,
         modality_decoder = mock_vae_decoder,
         transformer = dict(
-            dim = 512,
-            depth = 8
+            dim = 64,
+            depth = 2
         )
     )

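The mock VAE pair wired in above is not shown in this diff; judging from the Conv2d mock encoder visible in the test_velocity_consistency hunk further down, it is presumably just a pair of shape-compatible convolutions. A sketch under that assumption:

    # assumed shape of the mock VAE pair, mirroring the Conv2d mock encoder
    # used later in test_velocity_consistency
    from torch import nn

    mock_vae_encoder = nn.Conv2d(3, 384, 3, padding = 1)   # image -> 384-channel latent
    mock_vae_decoder = nn.Conv2d(384, 3, 3, padding = 1)   # latent -> image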
@@ -196,24 +196,26 @@ def test_text_image_end_to_end(

     # allow researchers to experiment with different time distributions across multiple modalities in a sample

-    def modality_length_to_times(modality_length):
-        has_modality = modality_length > 0
-        return torch.where(has_modality, torch.ones_like(modality_length), 0.)
+    def num_modalities_to_times(num_modalities):
+        batch = num_modalities.shape[0]
+        device = num_modalities.device
+        total_modalities = num_modalities.amax().item()
+        return torch.ones((batch, total_modalities), device = device)

-    time_fn = modality_length_to_times if custom_time_fn else None
+    time_fn = num_modalities_to_times if custom_time_fn else None

     # forward

     loss = model(
         text_and_images,
-        modality_length_to_times_fn = time_fn
+        num_modalities_to_times_fn = time_fn
     )

     loss.backward()

     # after much training

-    one_multimodal_sample = model.sample()
+    one_multimodal_sample = model.sample(max_length = 128)

 def test_velocity_consistency():
     mock_encoder = nn.Conv2d(3, 384, 3, padding = 1)
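The renamed hook makes its contract explicit: it receives a `(batch,)` integer tensor counting the modalities in each sample and returns times of shape `(batch, max_modalities)`. The test pins every time to 1.; here is a sketch of an alternative schedule that draws a random time per modality, assuming only the signature shown in the hunk above:

    # hypothetical alternative time function: random times in [0, 1) per
    # modality, instead of the all-ones schedule used by the test
    def random_modality_times(num_modalities):
        batch = num_modalities.shape[0]
        total_modalities = num_modalities.amax().item()
        return torch.rand((batch, total_modalities), device = num_modalities.device)

    loss = model(
        text_and_images,
        num_modalities_to_times_fn = random_modality_times
    )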
@@ -228,7 +230,7 @@ def test_velocity_consistency():
         modality_encoder = mock_encoder,
         modality_decoder = mock_decoder,
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 1
         )
     )
@@ -251,14 +253,9 @@ def test_velocity_consistency():
         ]
     ]

-    def modality_length_to_times(modality_length):
-        has_modality = modality_length > 0
-        return torch.where(has_modality, torch.ones_like(modality_length), 0.)
-
     loss, breakdown = model(
         text_and_images,
         velocity_consistency_ema_model = ema_model,
-        modality_length_to_times_fn = modality_length_to_times,
         return_breakdown = True
     )
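How `ema_model` is constructed lies outside this hunk. A generic sketch, assuming it is simply a frozen copy of the model whose weights track an exponential moving average of the online weights (the library may well ship its own EMA helper):

    # generic EMA sketch; copy.deepcopy plus a manual update is an assumption,
    # not necessarily how the test builds its ema_model
    import copy

    ema_model = copy.deepcopy(model).requires_grad_(False)

    @torch.no_grad()
    def update_ema(online, ema, decay = 0.999):
        for p, ema_p in zip(online.parameters(), ema.parameters()):
            ema_p.lerp_(p, 1. - decay)   # ema <- decay * ema + (1 - decay) * p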
@@ -275,7 +272,7 @@ def test_axial_pos_emb():
         add_pos_emb = True,
         modality_num_dim = (2, 1),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 8
         )
     )
@@ -295,7 +292,7 @@ def test_axial_pos_emb():

     # after much training

-    one_multimodal_sample = model.sample()
+    one_multimodal_sample = model.sample(max_length = 128)

 # unet related
