Skip to content

Commit d4799c8

Browse files
committed
auto move tensors within modality sample to the correct device
1 parent 1eb18f5 commit d4799c8

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.29"
3+
version = "0.0.30"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,10 @@ def forward(
826826
assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'
827827
assert self.dim_latents[modality_type] == modality_tensor.shape[-1], f'mismatch for modality latent dimension - expected {self.dim_latents[modality_type]} but received {modality_tensor.shape[-1]}'
828828

829+
# auto move modality tensor to device of model
830+
831+
modality_tensor = modality_tensor.to(device)
832+
829833
length = modality_tensor.shape[0]
830834

831835
# handle text

0 commit comments

Comments
 (0)