4 changes: 2 additions & 2 deletions benchmarks/benchmark_aq.py
@@ -44,7 +44,7 @@ def forward(self, x):
return x


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tensor_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for weight only quant API, used as a reference for
@@ -57,7 +57,7 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
deprecated_tensor_subclass, enable_parametrization=True, **kwargs
),
filter_fn,
)
14 changes: 5 additions & 9 deletions benchmarks/benchmark_gpu_sparsity.py
@@ -28,7 +28,7 @@
)


def benchmark_model_with_warmup(func, x, N_WARMUP=3):
def benchmark_model_with_warmup(func, N_WARMUP=3):
benchmark_model(func, N_WARMUP, device_type="cuda")
return benchmark_model(func, 10, device_type="cuda")

@@ -106,18 +106,14 @@ def sparse_func():
else:
raise ValueError(f"Unknown eval_fn: {args.eval_fn}")

dense_time = benchmark_model_with_warmup(dense_func, "dense.json.gz")
sparse_time = benchmark_model_with_warmup(sparse_func, "sparse.json.gz")
dense_time = benchmark_model_with_warmup(dense_func)
sparse_time = benchmark_model_with_warmup(sparse_func)

dense_func_c = torch.compile(dense_func, mode="max-autotune")
dense_time_c = benchmark_model_with_warmup(
dense_func_c, "dense_compile.json.gz"
)
dense_time_c = benchmark_model_with_warmup(dense_func_c)

sparse_func_c = torch.compile(sparse_func, mode="max-autotune")
sparse_time_c = benchmark_model_with_warmup(
sparse_func_c, "sparse_compile.json.gz"
)
sparse_time_c = benchmark_model_with_warmup(sparse_func_c)

torch._dynamo.reset()

4 changes: 2 additions & 2 deletions benchmarks/benchmark_hqq.py
@@ -8,9 +8,9 @@
import triton

if int(triton.__version__.split(".")[0]) < 3:
raise "triton >= 3.0.0 is required to run this test"
raise RuntimeError("triton >= 3.0.0 is required to run this test")
except ImportError:
raise "triton and hqq required to run this benchmark"
raise RuntimeError("triton and hqq required to run this benchmark")

from io import StringIO

12 changes: 7 additions & 5 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -8,12 +8,14 @@
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
from torchao.quantization.quant_api import (
_float8_cutlass_quant,
_float8_cutlass_quant_sparse,
from torchao.quantization.quant_api import quantize_

raise ImportError(
"This benchmark is broken: _float8_cutlass_quant and _float8_cutlass_quant_sparse "
"were removed when AffineQuantizedTensor (AQT) was deleted. "
"See torchao/quantization/quantize_/workflows/float8/sparse_2x4_cutlass_float8_tensor.py "
"for the current API."
)
from torchao.sparsity.utils import create_semi_structured_tensor

dtype = torch.bfloat16
dtypeq_X = torch.float8_e4m3fn
13 changes: 8 additions & 5 deletions benchmarks/benchmark_sparse_conversion_cutlass.py
@@ -10,12 +10,15 @@
from torchao.ops import (
to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.quant_api import (
_float8_cutlass_quant,
_float8_cutlass_quant_sparse,
)
from torchao.sparsity.utils import create_semi_structured_tensor

raise ImportError(
"This benchmark is broken: _float8_cutlass_quant and _float8_cutlass_quant_sparse "
"were removed when AffineQuantizedTensor (AQT) was deleted. "
"See torchao/quantization/quantize_/workflows/float8/sparse_2x4_cutlass_float8_tensor.py "
"for the current API."
)

dtype = torch.bfloat16
dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
@@ -117,7 +120,7 @@ def profile():
results.append(benchmark(m, 8192))

df = pd.DataFrame(results)
df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
df.to_csv("sparse_conversion_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))

# print("PROFILING")
10 changes: 4 additions & 6 deletions benchmarks/benchmark_uintx.py
@@ -3,15 +3,13 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy

import torch

from torchao.prototype.uintx import (
uintx_affine_weight_only,
unpack_cpu,
raise ImportError(
"This benchmark is broken: torchao.prototype.uintx module no longer exists. "
"The uintx functionality has been moved. "
"See torchao/prototype/dtypes/uintx/bitpacking.py for unpack_cpu."
)
from torchao.quantization.quant_api import quantize_


class Linear16(torch.nn.Module):
3 changes: 2 additions & 1 deletion benchmarks/intmm.py
@@ -99,7 +99,8 @@ def run_benchmarks(shapes):
assert file_path.is_file()

# Format is (m, k, n)
shapes = list(csv.reader(open(file_path, "r")))[1:]
with open(file_path, "r") as f:
shapes = list(csv.reader(f))[1:]
# Turn into list of int tuples
shapes = list(map(lambda x: tuple(map(int, x)), shapes))

225 changes: 225 additions & 0 deletions issues.md
@@ -0,0 +1,225 @@
# Issues Found in TorchAO Inference Examples

## 1. `int4_weight_only.py` - dtype mismatch

**File:** `docs/source/examples/inference/int4_weight_only.py`

### Problem
The example creates a model with the default dtype (float32) but uses `int4_packing_format="tile_packed_to_4d"`, which requires bfloat16.

### Description
The `Int4TilePackedTo4dTensor` only supports bfloat16 weights. When running the example, users get:
```
Only bfloat16 is supported for Int4TilePackedTo4dTensor, got torch.float32
```

This happens in `torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py:112-114`:
```python
assert hp_tensor.dtype == torch.bfloat16, (
f"Only bfloat16 is supported for Int4TilePackedTo4dTensor, got {hp_tensor.dtype}"
)
```

### Environment
- Config: `Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")`
- Supported: bfloat16
- Unsupported: float32, float16

### Steps to Solve
Add `dtype=torch.bfloat16` to the model creation:
```python
model = nn.Sequential(nn.Linear(2048, 2048, device="cuda", dtype=torch.bfloat16))
```

---

## 2. `float8_dynamic_activation_int4_weight.py` - missing dependency

**File:** `docs/source/examples/inference/float8_dynamic_activation_int4_weight.py`

### Problem
The example requires the `mslk` package (Meta's quantization library), which is not installed by default and is not mentioned in the example.

### Description
When running the example:
```
ImportError: Requires mslk >= 1.0.0
```

The `Float8DynamicActivationInt4WeightConfig` uses MSLK kernels for quantization. The check is in `torchao/quantization/quantize_/workflows/int4/int4_tensor.py:140`.

### Environment
- Package: MSLK (https://github.com/pytorch/MSLK)
- Version required: >= 1.0.0
- Not included in default TorchAO installation

### Steps to Solve
1. Document the dependency at the top of the example (a minimal import guard is sketched below)
2. Provide an alternative config that works without MSLK
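
A minimal sketch of the guard from step 1. The import name `mslk` is an assumption, not confirmed against the MSLK package:

```python
# Hypothetical import guard for the top of the example; the module name
# "mslk" is an assumption.
try:
    import mslk  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "This example requires mslk >= 1.0.0 (https://github.com/pytorch/MSLK); "
        "install it with: pip install mslk"
    ) from exc
```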

---

## 3. `float8_dynamic_activation_float8_weight.py` - hardware requirement

**File:** `docs/source/examples/inference/float8_dynamic_activation_float8_weight.py`

### Problem
Requires CUDA compute capability ≥8.9 (Ada/Hopper), MI300+, or XPU. The example fails on older GPUs without clearly indicating why.

### Description
When running on unsupported hardware:
```
Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+ or XPU.
```

This is a hardware constraint check in the float8 quantization code path.

### Environment
- Required compute capability: sm_89 (Ada) or higher
- Also supports: sm_90 (Hopper), MI300+, XPU
- Unsupported: sm_80 (Ampere - A100)

### Steps to Solve
Add a note at the top of the example indicating the hardware requirements. A runtime pre-check, as sketched below, would also fail fast with a clearer message.
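
A possible pre-check covering the CUDA case only, using the standard `torch.cuda.get_device_capability` API; the error wording is ours, and the MI300+/XPU branches are omitted:

```python
# Sketch of a runtime guard for the top of the example (CUDA path only).
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (8, 9):
        raise RuntimeError(
            "float8 dynamic quantization requires CUDA compute capability "
            f">= 8.9 (Ada/Hopper or newer); this GPU reports sm_{major}{minor}"
        )
```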

---

## 4. `uintx_weight_only.py` - missing dependency

**File:** `docs/source/examples/inference/uintx_weight_only.py`

### Problem
Requires `gemlite` package which is not installed by default and not documented in the example.

### Description
When running the example:
```
gemlite is required. Install with: pip install gemlite
```

This config is in the prototype/quantization module, indicating it's experimental and has external dependencies.

### Environment
- Package: gemlite
- Module: `torchao.prototype.quantization.UIntxWeightOnlyConfig`

### Steps to Solve
Document the dependency at the top of the example:
```python
# Requires: pip install gemlite
```

---

## 5. `int8_dynamic_activation_uintx_weight.py` - missing dependency

**File:** `docs/source/examples/inference/int8_dynamic_activation_uintx_weight.py`

### Problem
Same as #4: requires the `gemlite` package, which is not documented in the example.

### Description
When running the example:
```
gemlite is required. Install with: pip install gemlite
```

### Environment
- Package: gemlite
- Module: `torchao.prototype.quantization.Int8DynamicActivationUIntxWeightConfig`

### Steps to Solve
Document the dependency at the top of the example, as shown in #4.

---

## 6. `nvfp4_dynamic_activation_nvfp4_weight.py` - NVFP4 dynamic mode requires Blackwell architecture

**File:** `docs/source/examples/inference/nvfp4_dynamic_activation_nvfp4_weight.py`

### Problem
Running this example on non-Blackwell GPU hardware fails with an unhelpful assertion error:
```
AssertionError: NVFP4 DYNAMIC mode is only supported on sm100+ machines
```

### Description
NVFP4 (NVIDIA 4-bit floating point) is a format that uses E2M1 encoding with block scaling. The example uses `NVFP4DynamicActivationNVFP4WeightConfig` with default settings, which invokes the "DYNAMIC" quantization path (when `step=None`).

The dynamic quantization path requires the Blackwell architecture (sm_100) because it computes activation scales at runtime. This is implemented in `torchao/prototype/mx_formats/inference_workflow.py:309-311`:

```python
elif step is None:
# Dynamic quantization
assert is_sm_at_least_100(), (
"NVFP4 DYNAMIC mode is only supported on sm100+ machines"
)
```

The "sm100+" in the error message is cryptic: it gives no context about what the term means or which GPUs are affected.

### Environment
- **NVFP4 format:** NVIDIA 4-bit floating point with E2M1 encoding
- **Supported GPUs:** sm_100 (Blackwell: B200, GB200)
- **Unsupported GPUs:**
  - sm_80: Ampere (A100)
  - sm_89: Ada (RTX 4090)
  - sm_90: Hopper (H100)

### Steps to Solve
1. Add hardware requirement documentation at the top of the example:
```python
# Requires NVIDIA Blackwell (sm100) architecture
# NOT supported on: Ampere (A100), Hopper (H100), Ada (RTX 4090)
```

2. Add a pre-check with an informative error message before the assertion fails (see the sketch after this list)

3. Consider offering a fallback path for non-Blackwell GPUs using static quantization (PREPARE/CONVERT steps)
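
A minimal sketch of the pre-check from step 2, using the standard `torch.cuda.get_device_capability` API; the error wording and the pointer to the weight-only example are ours:

```python
# Sketch of a friendlier pre-check; sm_100 corresponds to a reported
# major compute capability of 10.
import torch

major, minor = torch.cuda.get_device_capability()
if major < 10:
    raise RuntimeError(
        "NVFP4 dynamic quantization requires an NVIDIA Blackwell GPU "
        f"(sm_100+, e.g. B200/GB200); detected sm_{major}{minor}. "
        "See nvfp4_weight_only.py for a path that runs on older GPUs."
    )
```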

### Contrast with `nvfp4_weight_only.py`
The sibling file `nvfp4_weight_only.py` works on all GPUs because it uses weight-only quantization (the static path via PREPARE/CONVERT steps), which does not require runtime activation scaling.

---

## 7. Missing Copyright Headers

Some example files have copyright headers while others don't, causing inconsistency.

### Problem
Inconsistent licensing headers across example files.

### Description
Only 3 files have copyright headers:
- `int8_dynamic_activation_int4_weight.py`
- `uintx_weight_only.py`
- `int8_dynamic_activation_uintx_weight.py`

The other 12 files lack headers.

### Environment
All inference examples in `docs/source/examples/inference/`

### Steps to Solve
Add consistent copyright headers to all files:
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
```

---

## Summary Table

| Example File | Issue | Severity |
|-------------|-------|----------|
| `int4_weight_only.py` | dtype mismatch (float32 vs bfloat16) | High |
| `float8_dynamic_activation_int4_weight.py` | missing `mslk` dependency | High |
| `float8_dynamic_activation_float8_weight.py` | hardware ≥sm_89 required | Medium |
| `uintx_weight_only.py` | missing `gemlite` dependency | High |
| `int8_dynamic_activation_uintx_weight.py` | missing `gemlite` dependency | High |
| `nvfp4_dynamic_activation_nvfp4_weight.py` | hardware ≥sm_100 (Blackwell) required | Medium |
| All files | Inconsistent copyright headers | Low |
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
@@ -639,7 +639,7 @@ def forward(self, x):
test = model_qc(x).detach()

assert SQNR(ref_f, test) > min_sqnr, (
f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
f"got sqnr: {SQNR(ref_f, test)}, expected: {min_sqnr}"
)
self.assertTrue(torch.equal(ref_q, test))
