@@ -7,6 +7,15 @@
import pytest
import torch

from torchao.utils import torch_version_at_least

if not (
    torch_version_at_least("2.7.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9
):
    pytest.skip("Requires CUDA capability >= 9.0", allow_module_level=True)

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
16 changes: 13 additions & 3 deletions test/prototype/blockwise_fp8_training/test_blockwise_linear.py
@@ -10,11 +10,21 @@
import torch
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.utils import is_sm_at_least_90
from torchao.utils import is_ROCM, torch_version_at_least

if not (
    torch_version_at_least("2.7.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9
):
    pytest.skip("Requires CUDA capability >= 9.0", allow_module_level=True)

if is_ROCM():
    pytest.skip(
        "Blockwise FP8 linear has numerical issues on ROCm", allow_module_level=True
    )

triton = pytest.importorskip("triton", reason="Triton required to run this test")
if not is_sm_at_least_90():
pytest.skip("This test requires SM90 or higher", allow_module_level=True)


from torchao.float8.float8_utils import compute_error
43 changes: 11 additions & 32 deletions torchao/prototype/blockwise_fp8_training/kernels.py
@@ -331,7 +331,8 @@ def triton_fp8_blockwise_act_quant_lhs_kernel(
    y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
    tl.store(y_ptr + y_offs, y, mask=y_mask)

    # Write reciprocal scales
    # Write reciprocal scales (mask needed because NUM_GROUPS may exceed M,
    # and unmasked writes to the column-major scale tensor corrupt later columns)
    scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
    scale_mask = m_offs[:, None] < M
    tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
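The added comment can be illustrated with plain PyTorch (this is not the Triton kernel; the sizes and names below are invented for the example). With a column-major scale tensor, element (m, k) sits at flat offset k * M + m, so an unmasked store for a padded row m >= M in column k lands on the start of column k + 1:

import torch

# Illustrative sizes only: 6 rows of scales, 3 K-blocks, padded group count of 8.
M, num_k_blocks, num_groups = 6, 3, 8

# Column-major scale tensor: element (m, k) lives at flat storage offset k * M + m.
s = torch.zeros(M, num_k_blocks).t().contiguous().t()
storage = s.t().reshape(-1)  # flat view over the column-major storage

k = 0  # pretend this program writes the scales for K-block 0
m_offs = torch.arange(num_groups)
offsets = k * M + m_offs  # same addressing as the kernel (row stride is 1)

# Unmasked write: rows M..num_groups-1 spill into the start of column k + 1.
storage[offsets] = 1.0
print(s[:, 1])  # first (num_groups - M) entries are clobbered

# Masked write (what the kernel does) leaves column k + 1 untouched.
s.zero_()
storage[offsets[m_offs < M]] = 1.0
print(s[:, 1])  # all zeros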
@@ -822,8 +823,8 @@ def torch_blockwise_scale_act_quant_rhs(
assert dtype in FP8_E4M3_DTYPES, f"dtype must be one of {FP8_E4M3_DTYPES}"

M, K = x.size()
max_fp8_e4m3 = torch.finfo(dtype).max
min_fp8_e4m3 = torch.finfo(dtype).min
fp8_max = torch.finfo(dtype).max
fp8_min = torch.finfo(dtype).min

# Reshape input to work with blocks of size (block_size, 1) along dimension 0
num_blocks_m = M // block_size
@@ -837,35 +838,18 @@

# Process each column (K dimension) separately
for k in range(K):
# Extract column k from all blocks: shape (num_blocks_m, block_size)
x_col = x_blocks[:, :, k] # (num_blocks_m, block_size)

# Compute absolute max for each block
amax = torch.abs(x_col).max(dim=1, keepdim=True)[0] # (num_blocks_m, 1)

# Clamp to avoid division by zero
x_col = x_blocks[:, :, k]
amax = torch.abs(x_col).max(dim=1, keepdim=True)[0]
amax = torch.clamp(amax, min=eps).to(torch.float64)

# Compute scales
scale = (max_fp8_e4m3 / amax).to(torch.float32) # (num_blocks_m, 1)

# Apply scaling
y_col = x_col * scale # (num_blocks_m, block_size)

# Clamp to FP8 range
y_col = torch.clamp(y_col, min=min_fp8_e4m3, max=max_fp8_e4m3)

# Store results
scale = (fp8_max / amax).to(torch.float32)
y_col = x_col * scale
y_col = torch.clamp(y_col, min=fp8_min, max=fp8_max)
y_blocks[:, :, k] = y_col.to(dtype)
scales[:, k] = scale.squeeze(-1) # (num_blocks_m,)

# Reshape back to original shape (removing padding if any)
y = y_blocks.view(-1, K)[:M, :] # (M, K)
scales[:, k] = scale.squeeze(-1)

# Convert to column-major format
y = y_blocks.view(-1, K)[:M, :]
y = y.t().contiguous().t()

# Return output tensor and reciprocal scales
return y, 1.0 / scales
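For reference, a sketch of how this reference quantizer could be exercised and undone. The import path follows the file shown in this diff; the block_size keyword and the input dtype are assumptions, not confirmed defaults:

import torch

from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_rhs,
)

# Illustrative sizes; M is chosen divisible by block_size so no padding is involved.
M, K, block_size = 256, 64, 128
x = torch.randn(M, K, dtype=torch.bfloat16)

y, recip_scales = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size)

# recip_scales holds one reciprocal scale per (block_size, 1) block, shape (M // block_size, K).
# Broadcasting it back over each block approximately recovers the input.
x_dequant = y.to(torch.float32) * recip_scales.repeat_interleave(block_size, dim=0)
print((x.to(torch.float32) - x_dequant).abs().max())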


@@ -878,16 +862,13 @@ def torch_blockwise_scale_weight_quant(x, tile_size=128, dtype=e4m3_dtype):
    assert x.is_contiguous(), "input tensor must be contiguous"
    height, width = x.shape

    # Compute block sizes
    t_h = height // tile_size
    t_w = width // tile_size

    # Reshape 2D input tensor into 4D tensor with shape (t_h, t_w, tile_size * tile_size)
    x = x.reshape(t_h, tile_size, t_w, tile_size)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(-1, tile_size * tile_size)

    # Compute amax along last dim (i.e., the block)
    x_amax = x.abs().max(dim=1).values.unsqueeze(1).to(torch.float64)
    x_amax = torch.clamp(x_amax, min=EPS, max=float("inf"))

@@ -897,11 +878,9 @@ def torch_blockwise_scale_weight_quant(x, tile_size=128, dtype=e4m3_dtype):

    x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(dtype)

    # Reshape quantized output and scales back to 2D
    x = x.reshape(t_h, t_w, tile_size, tile_size)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(height, width)
    s = s.reshape(t_h, t_w).to(torch.float)

    # Return output tensor and reciprocal scale
    return x, 1.0 / s
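The reshape/permute pair above is the usual trick for viewing a 2D matrix as a grid of tile_size x tile_size tiles. A small self-contained check of what the rearrangement produces (illustrative sizes, and a variant that keeps the tile grid 4D rather than flattening each tile):

import torch

tile_size = 4
height, width = 8, 12
x = torch.arange(height * width, dtype=torch.float32).reshape(height, width)

t_h, t_w = height // tile_size, width // tile_size
blocks = (
    x.reshape(t_h, tile_size, t_w, tile_size)
    .permute(0, 2, 1, 3)
    .reshape(t_h, t_w, tile_size, tile_size)
)

# Block (i, j) of the rearranged tensor is exactly the (i, j) tile of the matrix.
i, j = 1, 2
expected = x[i * tile_size : (i + 1) * tile_size, j * tile_size : (j + 1) * tile_size]
assert torch.equal(blocks[i, j], expected)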
4 changes: 2 additions & 2 deletions torchao/prototype/blockwise_fp8_training/linear.py
@@ -20,7 +20,7 @@
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import is_sm_at_least_90
from torchao.utils import is_ROCM, is_sm_at_least_90


class fp8_blockwise_mm(torch.autograd.Function):
@@ -158,7 +158,7 @@ def __init__(
        assert dtype in self.supported_dtypes, (
            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
        )
        assert is_sm_at_least_90(), "Only support SM90"
        assert is_sm_at_least_90() or is_ROCM(), "Requires CUDA SM >= 9.0 or ROCm"
        self.block_size = block_size
        self.dtype = dtype
        self.use_triton = use_triton
129 changes: 102 additions & 27 deletions torchao/utils.py
@@ -30,6 +30,12 @@
"TorchAOBaseTensor",
"is_cuda_version_at_least",
"is_MI300",
"is_MI325X",
"is_MI350",
"is_Navi4",
"is_ROCM",
"is_rocm_gpu_at_least",
"rocm_device_capability",
"is_sm_at_least_89",
"is_sm_at_least_90",
"is_sm_at_least_100",
@@ -1159,43 +1165,112 @@ def fill_defaults(args, n, defaults_tail):
    return r


# Supported AMD GPU Models and their LLVM gfx Codes:
# ---------------------------------------------------------------------------
# ROCm device capability helpers
#
# | AMD GPU Model | LLVM gfx Code |
# |---------------|------------------------|
# | Navi4 | gfx1200, gfx1201 |
# | MI300X | gfx940, gfx941, gfx942 |
# | MI350 | gfx950 |
# AMD GPU families and their LLVM GFX codes:
#
# | GPU Family | GFX Code | (major, minor) |
# |---------------|------------------------|----------------|
# | MI210 | gfx90a | (9, 0) |
# | MI300A/X | gfx940, gfx941, gfx942 | (9, 4) |
# | MI325X | gfx942 | (9, 4) |
# | MI350X | gfx950 | (9, 5) |
# | RDNA3 (Navi3) | gfx1100, gfx1101, ... | (11, 0) |
# | RDNA4 (Navi4) | gfx1200, gfx1201 | (12, 0) |
#
# The (major, minor) mapping mirrors HIP's hipDeviceProp_t derivation
# from the GFX arch string, following the convention established by
# vLLM (see: vllm/platforms/rocm.py).
# ---------------------------------------------------------------------------


def is_ROCM():
def is_ROCM() -> bool:
    return torch.cuda.is_available() and torch.version.hip is not None


def is_MI300():
    if is_ROCM():
        mxArchName = ["gfx940", "gfx941", "gfx942"]
        archName = torch.cuda.get_device_properties(0).gcnArchName
        for arch in mxArchName:
            if arch in archName:
                return True
    return False
def _parse_rocm_device_capability(gcn_arch: str) -> tuple[int, int] | None:
"""
Parse (major, minor) capability from a GCN architecture string.

Handles both 1-digit major (gfx9xx) and 2-digit major (gfx1xxx) formats:
gfx90a -> (9, 0)
gfx942 -> (9, 4)
gfx950 -> (9, 5)
gfx1100 -> (11, 0)
gfx1200 -> (12, 0)
"""
import re

m = re.match(r"gfx(\d+)", gcn_arch)
if not m:
return None

digits = m.group(1)
if len(digits) <= 3:
return (int(digits[0]), int(digits[1]))
elif len(digits) == 4:
return (int(digits[:2]), int(digits[2]))
return None


@functools.cache
def rocm_device_capability(device_id: int = 0) -> tuple[int, int] | None:
"""
Return (major, minor) device capability for a ROCm GPU, or None if
not running on ROCm. Caches the result after first call.
"""
if not is_ROCM():
return None
gcn_arch = torch.cuda.get_device_properties(device_id).gcnArchName
return _parse_rocm_device_capability(gcn_arch)


def is_MI300() -> bool:
"""MI300A/X: gfx940, gfx941, gfx942."""
if not is_ROCM():
return False
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(arch in gcn_arch for arch in ("gfx940", "gfx941", "gfx942"))

def is_MI350():
    if is_ROCM():
        archName = torch.cuda.get_device_properties(0).gcnArchName
        if "gfx950" in archName:
            return True
    return False

def is_MI325X() -> bool:
"""MI325X: gfx942 (same GFX code as MI300X, differentiated by device ID)."""
return is_MI300()

def is_Navi4():
    if is_ROCM():
        archName = torch.cuda.get_device_properties(0).gcnArchName
        if "gfx1200" or "gfx1201" in archName:
            return True
    return False

def is_MI350() -> bool:
"""MI350X: gfx950."""
if not is_ROCM():
return False
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return "gfx950" in gcn_arch


def is_Navi4() -> bool:
"""RDNA4 (Navi4): gfx1200, gfx1201."""
if not is_ROCM():
return False
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return "gfx1200" in gcn_arch or "gfx1201" in gcn_arch


def is_rocm_gpu_at_least(major: int, minor: int = 0) -> bool:
"""
Check if the ROCm GPU capability is at least (major, minor).

Allows writing capability gates that work across AMD GPU families,
analogous to is_sm_at_least_*() for NVIDIA GPUs.

Examples:
is_rocm_gpu_at_least(9, 4) # True on MI300X/MI325X/MI350X
is_rocm_gpu_at_least(9, 5) # True on MI350X only
is_rocm_gpu_at_least(12, 0) # True on RDNA4 (Navi4)
"""
cap = rocm_device_capability()
if cap is None:
return False
return cap >= (major, minor)
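A hypothetical example (not part of this PR) of how such a gate might be written at the top of a test module, combining the existing NVIDIA helper with the new ROCm one; the (9, 4) threshold is only an illustration:

import pytest

from torchao.utils import is_rocm_gpu_at_least, is_sm_at_least_90

# Hypothetical module-level gate: run only on SM90+ NVIDIA GPUs or on ROCm GPUs
# with capability >= (9, 4), i.e. MI300-class and newer.
if not (is_sm_at_least_90() or is_rocm_gpu_at_least(9, 4)):
    pytest.skip("Requires SM90+ or ROCm gfx94x+", allow_module_level=True)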


def is_sm_version(major: int, minor: int) -> bool: