Skip to content

[CI Fix] Propagate dtype in TorchAOBaseTensor._to_copy#4358

Closed
jainapurva wants to merge 3 commits intomainfrom
fix/propagate-dtype-in-to-copy
Closed

[CI Fix] Propagate dtype in TorchAOBaseTensor._to_copy#4358
jainapurva wants to merge 3 commits intomainfrom
fix/propagate-dtype-in-to-copy

Conversation

@jainapurva
Copy link
Copy Markdown
Contributor

Problem

PR #4297 added non_blocking propagation to TorchAOBaseTensor._to_copy, but introduced a bug: while _get_to_kwargs returns device, dtype, and non_blocking, the _to_copy handler only propagated device and non_blocking to inner tensors.

This meant that calls like tensor.to(dtype=torch.float16) or tensor.to(device='cuda', dtype=torch.bfloat16) would change the wrapper tensor's dtype but NOT the inner tensors (qdata, scale, etc.), causing a dtype mismatch between the wrapper and its data.

Fix

  • Pop dtype from kwargs and pass it to all inner .to() calls
  • Use explicit keyword arguments for clarity: device=device, dtype=dtype, non_blocking=non_blocking

This ensures all three parameters are consistently propagated to inner tensors when calling .to() on TorchAOBaseTensor subclasses.

Testing

Added test_to_copy_propagates_dtype_and_non_blocking to verify:

  • Dtype-only changes propagate correctly
  • Combined device + dtype + non_blocking changes work
  • All existing tests continue to pass

## Problem

PR #4297 added `non_blocking` propagation to TorchAOBaseTensor._to_copy, but
introduced a bug: while `_get_to_kwargs` returns `device`, `dtype`, and
`non_blocking`, the `_to_copy` handler only propagated `device` and
`non_blocking` to inner tensors.

This meant that calls like `tensor.to(dtype=torch.float16)` or
`tensor.to(device='cuda', dtype=torch.bfloat16)` would change the wrapper
tensor's dtype but NOT the inner tensors (qdata, scale, etc.), causing a
dtype mismatch between the wrapper and its data.

## Fix

- Pop `dtype` from kwargs and pass it to all inner `.to()` calls
- Use explicit keyword arguments for clarity: `device=device, dtype=dtype,
  non_blocking=non_blocking`

This ensures all three parameters are consistently propagated to inner
tensors when calling `.to()` on TorchAOBaseTensor subclasses.

## Testing

Added `test_to_copy_propagates_dtype_and_non_blocking` to verify:
- Dtype-only changes propagate correctly
- Combined device + dtype + non_blocking changes work
- All existing tests continue to pass

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@jainapurva jainapurva added the topic: bug fix Use this tag for PRs that fix bugs label Apr 30, 2026
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Apr 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4358

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit 23df18a with merge base 28e6aca (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 30, 2026
@jainapurva jainapurva added the module: not user facing Use this tag if you don't want this PR to show up in release notes label Apr 30, 2026
jainapurva and others added 2 commits April 30, 2026 23:17
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…hange

## Problem

PyTorch nightly now implements saturated casting to float8_e4m3fn in eager
mode, matching the behavior that was previously only in compiled/triton mode.

The test `test_cast_to_float8_e4m3fn_saturation_behavior` was expecting the
old unsaturated behavior (out-of-range values → NaN), causing H100 CI failures.

## Fix

Updated the test to verify the new saturated casting behavior:
- Changed assertion from expecting NaN to expecting saturation
- Added verification that out-of-range values are clamped to max_val
- Updated assertions to verify eager and compiled modes produce identical results
- Updated comments to reflect the completed TODO from issue #1912

## Testing

This fixes the H100 test failures on main branch where the test was asserting:
```python
assert torch.all(torch.isnan(data_out_of_range_f8))  # Old behavior
```

But PyTorch now produces saturated values (448/-448) instead of NaN.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@jainapurva jainapurva closed this May 1, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: not user facing Use this tag if you don't want this PR to show up in release notes topic: bug fix Use this tag for PRs that fix bugs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant