Skip to content

Commit aa5f34e

Browse files
authored
Update benchmarks scripts. (#151)
1 parent 0e8933d commit aa5f34e

File tree

15 files changed

+191
-48
lines changed

15 files changed

+191
-48
lines changed

README.md

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,20 +88,29 @@ while the latter needs valid `CuArray`s to be passed to the kernel.
8888
Run benchmarks with:
8989

9090
```bash
91-
julia --project examples/benchmarks.jl # Julia
92-
uv run python examples/benchmarks.py # Python (for comparison)
91+
julia --project=examples examples/benchmarks.jl # Julia
92+
uv run python examples/benchmarks.py # Python (for comparison)
9393
```
9494

95-
Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080:
96-
97-
| Kernel | Julia | Python | Status |
98-
|--------|-------|--------|--------|
99-
| Vector Addition | 841 GB/s | 847 GB/s | OK (=) |
100-
| Matrix Transpose | 807 GB/s | 813 GB/s | OK (-1%) |
101-
| Layer Normalization | 653 GB/s | 758 GB/s | -14% |
102-
| Matrix Multiplication | 43.1 TFLOPS | 50.3 TFLOPS | -14% |
103-
| Batch Matrix Multiply | 30.4 TFLOPS | 40.0 TFLOPS | -24% |
104-
| FFT (3-stage Cooley-Tukey) | 620 μs | 486 μs | -28% |
95+
Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (20 runs, 5 warmup,
96+
min time reported):
97+
98+
| Kernel | Size | Julia | Python | Status |
99+
|--------|------|-------|--------|--------|
100+
| Vector Addition | 2^27 f32 | 841 GB/s | 847 GB/s | OK (=) |
101+
| Matrix Transpose | 8192² f32 | 773 GB/s | 817 GB/s | -5% |
102+
| Layer Normalization | 4096² f32 fwd | 615 GB/s | 761 GB/s | -19% |
103+
| Matrix Multiplication | 4096³ f32 | 47.6 TFLOPS | 50.2 TFLOPS | -5% |
104+
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 28.7 TFLOPS | 40.0 TFLOPS | -28% |
105+
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 465 μs | 486 μs | OK (+4%) |
106+
107+
With the same tileiras compiler, all kernels compile to identical register counts, block sizes, and
108+
occupancy. The remaining gap is from **1→0 indexing overhead**: Julia's 1-based `bid()` and
109+
load indices generate extra `subi` ops in the Tile IR that perturb tileiras's SASS
110+
instruction scheduling (e.g. missing `.reuse` operand collector flags on HMMA, different
111+
address computation instruction selection). This affects all kernels proportionally to loop
112+
count (layernorm 174 vs 128 IR lines across 3 loops; batchmatmul L1 hit rate 9.5% vs 41.3%
113+
from cascading scheduling differences).
105114

106115

107116
## Supported Operations

examples/batchmatmul.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ function verify(data, result)
104104
@assert isapprox(Array(result.C), expected; rtol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))"
105105
end
106106

107+
function metric(data)
108+
# 2*M*K*N*Batch FLOPs (multiply-add = 2 ops)
109+
return 2 * data.M * data.K * data.N * data.Batch, "TFLOPS"
110+
end
111+
107112
#=============================================================================
108113
Reference implementations for benchmarking
109114
=============================================================================#

examples/batchmatmul.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def run(data, *, tm: int = 128, tn: int = 128, tk: int = 64, nruns: int = 1, war
9393
return {"C": C, "times": times}
9494

9595

96+
def metric(data):
97+
"""Return (total_flops, unit) for throughput calculation."""
98+
# 2*M*K*N*Batch FLOPs (multiply-add = 2 ops)
99+
return 2 * data["M"] * data["K"] * data["N"] * data["Batch"], "TFLOPS"
100+
101+
96102
def verify(data, result):
97103
"""Verify batch matmul results."""
98104
A_np = cp.asnumpy(data["A"]).astype(np.float32)

examples/benchmarks.jl

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ using CUDA
99
Configuration
1010
=============================================================================#
1111

12-
const NRUNS = 10
13-
const WARMUP = 3
12+
const NRUNS = 20
13+
const WARMUP = 5
1414

1515
#=============================================================================
1616
Benchmark Utilities
@@ -20,20 +20,45 @@ struct BenchmarkResult
2020
name::String
2121
min_ms::Float64
2222
mean_ms::Float64
23+
throughput::String # e.g. "841 GB/s" or "43.1 TFLOPS" or ""
24+
end
25+
26+
function format_throughput(total, unit::String, time_ms::Float64)
27+
if unit == "GB/s"
28+
gbps = total / (time_ms / 1000) / 1e9
29+
return "$(round(Int, gbps)) GB/s"
30+
elseif unit == "TFLOPS"
31+
tflops = total / (time_ms / 1000) / 1e12
32+
return "$(round(tflops, digits=1)) TFLOPS"
33+
elseif unit == "μs"
34+
return "$(round(Int, time_ms * 1000)) μs"
35+
else
36+
return ""
37+
end
2338
end
2439

2540
function print_table(title::String, results::Vector{BenchmarkResult})
2641
println()
27-
println("=" ^ 60)
42+
println("=" ^ 72)
2843
println(" ", title)
29-
println("=" ^ 60)
30-
println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)")
31-
println("-" ^ 60)
44+
println("=" ^ 72)
45+
has_throughput = any(r -> !isempty(r.throughput), results)
46+
if has_throughput
47+
println(rpad("Implementation", 20), rpad("Min (ms)", 12), rpad("Mean (ms)", 12), "Throughput")
48+
else
49+
println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)")
50+
end
51+
println("-" ^ 72)
3252
for r in results
33-
println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12),
34-
round(r.mean_ms, digits=3))
53+
if has_throughput
54+
println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12),
55+
rpad(round(r.mean_ms, digits=3), 12), r.throughput)
56+
else
57+
println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12),
58+
round(r.mean_ms, digits=3))
59+
end
3560
end
36-
println("-" ^ 60)
61+
println("-" ^ 72)
3762
end
3863

3964
#=============================================================================
@@ -65,6 +90,12 @@ function run_benchmark(name::String)
6590
# Prepare data with benchmark=true for larger sizes
6691
data = @invokelatest mod.prepare(; benchmark=true)
6792

93+
# Get metric info if available
94+
metric_total, metric_unit = 0, ""
95+
if isdefined(mod, :metric)
96+
metric_total, metric_unit = @invokelatest mod.metric(data)
97+
end
98+
6899
# Run cuTile
69100
result = @invokelatest mod.run(data; nruns=NRUNS, warmup=WARMUP)
70101

@@ -86,17 +117,17 @@ function run_benchmark(name::String)
86117
merge!(results, others)
87118
end
88119

89-
return results
120+
return results, metric_total, metric_unit
90121
end
91122

92123
#=============================================================================
93124
Main
94125
=============================================================================#
95126

96127
function main()
97-
println("=" ^ 60)
128+
println("=" ^ 72)
98129
println(" cuTile.jl Benchmarks")
99-
println("=" ^ 60)
130+
println("=" ^ 72)
100131
println()
101132
println("Configuration:")
102133
println(" Runs: $NRUNS (+ $WARMUP warmup)")
@@ -105,18 +136,21 @@ function main()
105136
for name in discover_benchmarks()
106137
println("\nBenchmarking $name...")
107138

108-
results = run_benchmark(name)
109-
if results === nothing
139+
ret = run_benchmark(name)
140+
if ret === nothing
110141
println(" (skipped - no prepare/run functions)")
111142
continue
112143
end
113144

145+
results, metric_total, metric_unit = ret
146+
114147
# Convert to BenchmarkResult for printing
115148
benchmark_results = BenchmarkResult[]
116149
for (impl_name, times) in results
117150
min_t = minimum(times)
118151
mean_t = sum(times) / length(times)
119-
push!(benchmark_results, BenchmarkResult(impl_name, min_t, mean_t))
152+
tp = !isempty(metric_unit) ? format_throughput(metric_total, metric_unit, min_t) : ""
153+
push!(benchmark_results, BenchmarkResult(impl_name, min_t, mean_t, tp))
120154
end
121155

122156
# Sort by min time
@@ -126,9 +160,9 @@ function main()
126160
end
127161

128162
println()
129-
println("=" ^ 60)
163+
println("=" ^ 72)
130164
println(" Benchmark Complete")
131-
println("=" ^ 60)
165+
println("=" ^ 72)
132166
end
133167

134168
if abspath(PROGRAM_FILE) == @__FILE__

examples/benchmarks.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,52 @@
1212
# Configuration
1313
#=============================================================================
1414

15-
NRUNS = 10
16-
WARMUP = 3
15+
NRUNS = 20
16+
WARMUP = 5
1717

1818
#=============================================================================
1919
# Benchmark Utilities
2020
#=============================================================================
2121

2222
class BenchmarkResult:
23-
def __init__(self, name: str, min_ms: float, mean_ms: float):
23+
def __init__(self, name: str, min_ms: float, mean_ms: float, throughput: str = ""):
2424
self.name = name
2525
self.min_ms = min_ms
2626
self.mean_ms = mean_ms
27+
self.throughput = throughput
28+
29+
30+
def format_throughput(total, unit: str, time_ms: float) -> str:
31+
if unit == "GB/s":
32+
gbps = total / (time_ms / 1000) / 1e9
33+
return f"{gbps:.0f} GB/s"
34+
elif unit == "TFLOPS":
35+
tflops = total / (time_ms / 1000) / 1e12
36+
return f"{tflops:.1f} TFLOPS"
37+
elif unit == "μs":
38+
return f"{time_ms * 1000:.0f} μs"
39+
else:
40+
return ""
2741

2842

2943
def print_table(title: str, results: list):
3044
"""Print formatted benchmark results table."""
3145
print()
32-
print("=" * 60)
46+
print("=" * 72)
3347
print(f" {title}")
34-
print("=" * 60)
35-
print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)")
36-
print("-" * 60)
48+
print("=" * 72)
49+
has_throughput = any(r.throughput for r in results)
50+
if has_throughput:
51+
print(f"{'Implementation':<20}{'Min (ms)':<12}{'Mean (ms)':<12}Throughput")
52+
else:
53+
print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)")
54+
print("-" * 72)
3755
for r in results:
38-
print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}")
39-
print("-" * 60)
56+
if has_throughput:
57+
print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:<12.3f}{r.throughput}")
58+
else:
59+
print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}")
60+
print("-" * 72)
4061

4162

4263
#=============================================================================
@@ -76,6 +97,10 @@ def run_benchmark(name: str):
7697
# Prepare data with benchmark=True for larger sizes
7798
data = prepare_fn(benchmark=True)
7899

100+
# Get metric info if available
101+
metric_fn = getattr(mod, "metric", None)
102+
metric_total, metric_unit = (0, "") if not metric_fn else metric_fn(data)
103+
79104
# Run cuTile
80105
result = run_fn(data, nruns=NRUNS, warmup=WARMUP)
81106

@@ -96,7 +121,7 @@ def run_benchmark(name: str):
96121
others = run_others_fn(data, nruns=NRUNS, warmup=WARMUP)
97122
results.update(others)
98123

99-
return results
124+
return results, metric_total, metric_unit
100125

101126

102127
#=============================================================================
@@ -106,9 +131,9 @@ def run_benchmark(name: str):
106131
def main():
107132
import torch # For GPU name
108133

109-
print("=" * 60)
134+
print("=" * 72)
110135
print(" cuTile Python Benchmarks")
111-
print("=" * 60)
136+
print("=" * 72)
112137
print()
113138
print("Configuration:")
114139
print(f" Runs: {NRUNS} (+ {WARMUP} warmup)")
@@ -117,27 +142,30 @@ def main():
117142
for name in discover_benchmarks():
118143
print(f"\nBenchmarking {name}...")
119144

120-
results = run_benchmark(name)
121-
if results is None:
145+
ret = run_benchmark(name)
146+
if ret is None:
122147
print(" (skipped - no prepare/run functions)")
123148
continue
124149

150+
results, metric_total, metric_unit = ret
151+
125152
# Convert to BenchmarkResult for printing
126153
benchmark_results = []
127154
for impl_name, times in results.items():
128155
min_t = min(times)
129156
mean_t = sum(times) / len(times)
130-
benchmark_results.append(BenchmarkResult(impl_name, min_t, mean_t))
157+
tp = format_throughput(metric_total, metric_unit, min_t) if metric_unit else ""
158+
benchmark_results.append(BenchmarkResult(impl_name, min_t, mean_t, tp))
131159

132160
# Sort by min time
133161
benchmark_results.sort(key=lambda r: r.min_ms)
134162

135163
print_table(name, benchmark_results)
136164

137165
print()
138-
print("=" * 60)
166+
print("=" * 72)
139167
print(" Benchmark Complete")
140-
print("=" * 60)
168+
print("=" * 72)
141169

142170

143171
if __name__ == "__main__":

examples/fft.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ function verify(data, result)
266266
@assert isapprox(Array(result.output), reference, rtol=1e-4)
267267
end
268268

269+
function metric(data)
270+
# FFT is a latency benchmark; report time directly
271+
return 0, "μs"
272+
end
273+
269274
#=============================================================================
270275
Reference implementations for benchmarking
271276
=============================================================================#

examples/fft.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ def run(data, *, nruns: int = 1, warmup: int = 0):
176176
return {"output": output, "times": times}
177177

178178

179+
def metric(data):
180+
"""Return (0, unit) for FFT - latency benchmark, report time directly."""
181+
return 0, "μs"
182+
183+
179184
def verify(data, result):
180185
"""Verify FFT results."""
181186
reference = torch.fft.fft(data["input"], dim=-1)

examples/layernorm.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,11 @@ function test_layernorm(M, N, TILE_N; TILE_M::Int=32, eps::Float32=1f-5, name=no
385385
println(" fwd passed, bwd passed")
386386
end
387387

388+
function metric(data)
389+
# Forward: 3 reads of X + W + B reads + Y write + Mean/Rstd writes ≈ 4*M*N floats
390+
return 4 * data.M * data.N * sizeof(Float32), "GB/s"
391+
end
392+
388393
# No run_others for layernorm - no simple reference implementation to compare against
389394

390395
#=============================================================================

examples/layernorm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,12 @@ def verify(data, result):
254254
assert np.allclose(cp.asnumpy(result["DB"]), expected_DB, rtol=rtol, atol=atol), \
255255
f"DB mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DB']) - expected_DB))}"
256256

257+
def metric(data):
258+
"""Return (total_bytes, unit) for throughput calculation."""
259+
# Forward: 3 reads of X + W + B reads + Y write + Mean/Rstd writes ≈ 4*M*N floats
260+
return 4 * data["M"] * data["N"] * 4, "GB/s"
261+
262+
257263
# No run_others for layernorm - no simple reference implementation to compare against
258264

259265

examples/matmul.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ function verify(data, result)
105105
@assert isapprox(Array(result.C), expected; rtol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))"
106106
end
107107

108+
function metric(data)
109+
# 2*M*N*K FLOPs (multiply-add = 2 ops)
110+
return 2 * data.M * data.N * data.K, "TFLOPS"
111+
end
112+
108113
#=============================================================================
109114
Reference implementations for benchmarking
110115
=============================================================================#

0 commit comments

Comments (0)