Commit 7467727

r3t2 authored and facebook-github-bot committed
Use in-place ops in _quantize_affine_float8 to reduce peak memory
Summary: [torchao] Use in-place ops in _quantize_affine_float8 to reduce peak memory

`_quantize_affine_float8` allocated up to 3 separate float32 copies of the input tensor (via `.to()`, `/`, and `.clamp()`). For large activations this caused unnecessary memory pressure and OOMs. Switch to in-place `div_()` and `clamp_()` so only a single float32 copy is ever live. Use `copy=True` on the `.to()` call to guarantee a fresh buffer even when the input is already float32, preventing mutation of the caller's tensor.

Differential Revision: D96350390
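The allocation pattern described above can be sketched outside of torch. A minimal NumPy stand-in (hypothetical function names, not the torchao code; FP8_MAX assumed to be the float8_e4m3fn maximum, 448.0): chaining out-of-place ops materializes a fresh full-size float32 temporary at every step, while the in-place variant reuses one buffer.

```python
import numpy as np

FP8_MAX = 448.0  # assumed: max finite value of float8_e4m3fn

def quantize_out_of_place(tensor, scale):
    """Naive version: each step allocates a new full-size float32 array."""
    t32 = tensor.astype(np.float32)               # copy 1
    scaled = t32 / scale                          # copy 2
    clamped = np.clip(scaled, -FP8_MAX, FP8_MAX)  # copy 3
    return clamped

def quantize_in_place(tensor, scale):
    """In-place version: a single float32 buffer is reused for every step."""
    t32 = tensor.astype(np.float32, copy=True)  # fresh buffer, caller untouched
    np.divide(t32, scale, out=t32)              # in-place divide
    np.clip(t32, -FP8_MAX, FP8_MAX, out=t32)    # in-place clamp
    return t32

x = np.array([1.0, -900.0, 900.0], dtype=np.float32)
out = quantize_in_place(x, 2.0)
# out-of-range values are clamped to +/-FP8_MAX; x itself is unchanged
```

Both versions compute the same result; only the number of live full-size temporaries differs, which is exactly the 3x-to-1x peak-memory reduction the commit message claims.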
1 parent 95d366c commit 7467727

1 file changed: torchao/quantization/quant_primitives.py

Lines changed: 9 additions & 4 deletions
@@ -2327,15 +2327,20 @@ def _quantize_affine_float8(
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
     """
-    tensor_fp32 = tensor.to(torch.float32)
+    # copy=True guarantees a fresh tensor even when the input is already fp32,
+    # so the in-place div_/clamp_ below never mutate the caller's tensor.
+    tensor_fp32 = tensor.to(torch.float32, copy=True)

     # Expand scale to match tensor dimensions for block-wise quantization
     scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)

-    tensor_scaled = tensor_fp32 / scale_expanded
+    # Use in-place ops to avoid allocating additional float32 copies of the
+    # full tensor. This reduces peak memory from 3x to 1x the float32
+    # tensor size — critical for large activations (e.g. video VAE decode).
+    tensor_fp32.div_(scale_expanded)
     max_value = torch.finfo(float8_dtype).max
-    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
-    return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
+    tensor_fp32.clamp_(min=-max_value, max=max_value)
+    return _RoundToFloat8.apply(tensor_fp32, float8_dtype)


 def _dequantize_affine_float8(
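The `copy=True` detail matters because `Tensor.to` returns the input tensor itself when the dtype already matches, so the in-place ops would write straight through to the caller's data. The same hazard can be shown with NumPy's `astype` (used here as a stand-in for `Tensor.to`; `copy=False` in NumPy, like torch's default, may alias the input when no conversion is needed):

```python
import numpy as np

x = np.array([2.0, 4.0], dtype=np.float32)

# copy=False may hand back the very same buffer when the dtype already matches...
alias = x.astype(np.float32, copy=False)
alias /= 2.0             # ...so this in-place divide also changes x
mutated = x.copy()       # snapshot: x has been silently mutated

# copy=True always allocates a fresh buffer, so the caller's array survives
y = np.array([2.0, 4.0], dtype=np.float32)
fresh = y.astype(np.float32, copy=True)
fresh /= 2.0             # only the private copy changes; y is intact
```

This is why the commit pairs the switch to `div_()`/`clamp_()` with `copy=True` on the `.to()` call: in-place ops are only safe once the function owns the buffer it mutates.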
