diff --git a/janus/models/vq_model.py b/janus/models/vq_model.py index 887b721..86ddfdc 100755 --- a/janus/models/vq_model.py +++ b/janus/models/vq_model.py @@ -231,7 +231,7 @@ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): self.embedding.weight.data, p=2, dim=-1 ) if self.show_usage: - self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) + self.register_buffer("codebook_used", torch.zeros(self.n_e)) def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten @@ -416,8 +416,9 @@ def __init__(self, in_channels, with_conv): def forward(self, x): if x.dtype != torch.float32: + orig_dtype = x.dtype x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to( - torch.bfloat16 + orig_dtype ) else: x = F.interpolate(x, scale_factor=2.0, mode="nearest")