
fix: preserve original dtype in Upsample instead of hardcoding bfloat16#233

Open
Mr-Neutr0n wants to merge 1 commit into deepseek-ai:main from Mr-Neutr0n:fix/upsample-preserve-dtype

Conversation

@Mr-Neutr0n

Summary

The Upsample.forward method in janus/models/vq_model.py hardcodes a cast to torch.bfloat16 after interpolation when the input is not float32:

x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
    torch.bfloat16
)

This causes a problem when the input tensor is float16 (common on GPUs without native bfloat16 support, e.g. older NVIDIA architectures): the tensor is silently converted from float16 to bfloat16 after interpolation, leading to dtype mismatches in downstream convolution layers whose weights are still float16.
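A minimal standalone repro of the silent conversion under the current code (the tensor shape here is arbitrary, chosen only for illustration):

```python
import torch
import torch.nn.functional as F

# float16 input, as produced on GPUs without native bfloat16 support
x = torch.randn(1, 4, 8, 8, dtype=torch.float16)

# current behavior: the result is unconditionally cast to bfloat16
y = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
    torch.bfloat16
)
print(x.dtype, "->", y.dtype)  # torch.float16 -> torch.bfloat16
```

Passing `y` on to a convolution whose weights are still float16 then fails with a dtype error instead of running.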

Fix

Store the original dtype before casting to float32 for interpolation, then cast back to it afterward:

orig_dtype = x.dtype
x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
    orig_dtype
)

This preserves whatever dtype the input originally had (float16, bfloat16, etc.) instead of always forcing bfloat16.

Test plan

  • Verified that the fix correctly preserves float16 input dtype through the upsample operation
  • Verified that bfloat16 inputs continue to work as before (orig_dtype would be bfloat16, matching the previous behavior)
  • No functional change for float32 inputs (the else branch is unchanged)
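The bullets above can be reproduced with a short self-contained check; the helper name below is hypothetical and simply mirrors the patched forward logic:

```python
import torch
import torch.nn.functional as F

def upsample_nearest_preserve_dtype(x: torch.Tensor) -> torch.Tensor:
    """Mirror of the patched Upsample.forward interpolation step."""
    if x.dtype != torch.float32:
        orig_dtype = x.dtype  # remember the incoming dtype
        # interpolate in float32 (low-precision nearest kernels are not
        # available on every backend), then cast back to the original dtype
        x = F.interpolate(
            x.to(torch.float), scale_factor=2.0, mode="nearest"
        ).to(orig_dtype)
    else:
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
    return x

# float16 and bfloat16 are preserved; float32 takes the unchanged else branch
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    x = torch.randn(1, 4, 8, 8, dtype=dtype)
    y = upsample_nearest_preserve_dtype(x)
    assert y.dtype == dtype and tuple(y.shape) == (1, 4, 16, 16)
print("all dtypes preserved")
```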

