Skip to content

Commit ee3d62a

Browse files
authored (author name lost in page extraction)
small fix for inference roofline model (#3990)
Update [ghstack-poisoned]
1 parent b8708a2 commit ee3d62a

3 files changed

Lines changed: 10 additions & 8 deletions

File tree

benchmarks/float8/float8_inference_roofline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def run(
581581
M, K, N, torch.float8_e4m3fn, gemm_recipe_name
582582
)
583583
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
584+
print("bf16_ovhd_time_sympy", bf16_ovhd_time_sympy)
584585
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
585586
print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
586587
print()

docs/source/workflows/inference.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ torch version 2.12.0.dev20260218+cu130
121121
torchao version 0.17.0+git3075bb624
122122
...
123123
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
124-
0 1024 1024 1024 0.64 0.94
125-
1 2048 2048 2048 1.75 1.21
126-
2 4096 4096 4096 1.90 1.45
127-
3 8192 8192 8192 1.94 1.75
124+
0 1024 1024 1024 1.00 0.93
125+
1 2048 2048 2048 1.75 1.20
126+
2 4096 4096 4096 1.90 1.46
127+
3 8192 8192 8192 1.94 1.76
128128
4 16384 16384 16384 1.97 1.77
129129

130130
#
@@ -137,11 +137,11 @@ torch version 2.12.0.dev20260218+cu130
137137
torchao version 0.17.0+git3075bb624
138138
...
139139
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
140-
0 1024 1024 1024 0.64 0.37
141-
1 2048 2048 2048 2.39 0.74
140+
0 1024 1024 1024 1.00 0.38
141+
1 2048 2048 2048 2.39 0.73
142142
2 4096 4096 4096 2.92 1.19
143-
3 8192 8192 8192 3.34 1.78
144-
4 16384 16384 16384 3.63 2.57
143+
3 8192 8192 8192 3.34 1.80
144+
4 16384 16384 16384 3.63 2.56
145145
```
146146

147147
## Other Available Quantization Techniques

torchao/testing/training/roofline_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def get_inference_bf16_activation_mem_sympy(M, K, N, gpu_name: Optional[str] = N
561561
kernel_rw = BYTES_PER_EL_BF16 * M * K * 2
562562
# convert from bytes to seconds
563563
res_s = kernel_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
564+
res_s = sympy.Max(res_s, KERNEL_LAUNCH_OVERHEAD_SEC)
564565
return res_s
565566

566567

0 commit comments

Comments (0)