diff --git a/janus/models/vq_model.py b/janus/models/vq_model.py index 887b721..bf99054 100755 --- a/janus/models/vq_model.py +++ b/janus/models/vq_model.py @@ -416,8 +416,9 @@ def __init__(self, in_channels, with_conv): def forward(self, x): if x.dtype != torch.float32: + orig_dtype = x.dtype x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to( - torch.bfloat16 + orig_dtype ) else: x = F.interpolate(x, scale_factor=2.0, mode="nearest")