@@ -9,8 +9,8 @@ using CUDA
99 Configuration
1010=============================================================================#
1111
12- const NRUNS = 10
13- const WARMUP = 3
12+ const NRUNS = 20
13+ const WARMUP = 5
1414
1515#= ============================================================================
1616 Benchmark Utilities
@@ -20,20 +20,45 @@ struct BenchmarkResult
2020 name:: String
2121 min_ms:: Float64
2222 mean_ms:: Float64
23+ throughput:: String # e.g. "841 GB/s" or "43.1 TFLOPS" or ""
24+ end
25+
"""
    format_throughput(total, unit, time_ms) -> String

Format the rate achieved when `total` units of work complete in `time_ms`
milliseconds, as a short human-readable string.

# Arguments
- `total`: amount of work. It is divided by `1e9` for `"GB/s"` and by `1e12` for
  `"TFLOPS"`; it is ignored for `"μs"`.
- `unit`: one of `"GB/s"`, `"TFLOPS"` or `"μs"`. Surrounding whitespace is ignored.
- `time_ms`: elapsed time in milliseconds (any `Real`).

# Returns
- `"<n> GB/s"` with the rate rounded to an integer,
- `"<x> TFLOPS"` with the rate rounded to one decimal digit,
- `"<n> μs"` with `time_ms` converted to whole microseconds,
- `""` for any other unit, or when no finite value can be computed
  (for example `time_ms == 0`).
"""
function format_throughput(total, unit::AbstractString, time_ms::Real)
    label = strip(unit)

    if label == "μs"
        micros = time_ms * 1000
        # round(Int, x) throws InexactError on NaN/Inf; report "no value" instead.
        isfinite(micros) || return ""
        return "$(round(Int, micros)) μs"
    end

    scale = if label == "GB/s"
        1e9
    elseif label == "TFLOPS"
        1e12
    else
        # Unknown or empty unit: nothing to report.
        return ""
    end

    rate = total / (time_ms / 1000) / scale
    # A zero elapsed time yields Inf (or NaN when total is also zero); neither can
    # be rounded to an Int, so fall back to the empty "no throughput" string.
    isfinite(rate) || return ""

    if label == "GB/s"
        return "$(round(Int, rate)) GB/s"
    end
    return "$(round(rate; digits=1)) TFLOPS"
end
2439
# print_table(title, results)
#
# Print one results table with `println`: a rule of "=" characters, the title,
# a second rule, a column-header row, a "-" rule, one row per entry of
# `results` (in the order given), and a closing "-" rule.
#
# A "Throughput" column is shown only when at least one result carries a
# non-empty `throughput` string; otherwise the table keeps the three-column
# layout (Implementation / Min (ms) / Mean (ms)).
#
# Column layout: name padded to 20, min and mean rounded to 3 digits and
# padded to 12 (the mean is left unpadded when it is the last column).
#
# NOTE(review): this span is diff markup — lines marked `-` are the previous
# 60-wide layout, lines marked `+` are the 72-wide replacement. Whitespace
# inside the string literals looks altered by extraction; confirm against the
# real file before relying on exact padding.
2540function print_table (title:: String , results:: Vector{BenchmarkResult} )
2641 println ()
27- println (" =" ^ 60 )
42+ println (" =" ^ 72 )
2843 println (" " , title)
29- println (" =" ^ 60 )
30- println (rpad (" Implementation" , 20 ), rpad (" Min (ms)" , 12 ), " Mean (ms)" )
31- println (" -" ^ 60 )
44+ println (" =" ^ 72 )
# Decide once, for the whole table, whether the extra column is needed.
45+ has_throughput = any (r -> ! isempty (r. throughput), results)
46+ if has_throughput
47+ println (rpad (" Implementation" , 20 ), rpad (" Min (ms)" , 12 ), rpad (" Mean (ms)" , 12 ), " Throughput" )
48+ else
49+ println (rpad (" Implementation" , 20 ), rpad (" Min (ms)" , 12 ), " Mean (ms)" )
50+ end
51+ println (" -" ^ 72 )
3252 for r in results
33- println (rpad (r. name, 20 ), rpad (round (r. min_ms, digits= 3 ), 12 ),
34- round (r. mean_ms, digits= 3 ))
# Rows mirror the header: the mean is padded only when a throughput cell follows it.
53+ if has_throughput
54+ println (rpad (r. name, 20 ), rpad (round (r. min_ms, digits= 3 ), 12 ),
55+ rpad (round (r. mean_ms, digits= 3 ), 12 ), r. throughput)
56+ else
57+ println (rpad (r. name, 20 ), rpad (round (r. min_ms, digits= 3 ), 12 ),
58+ round (r. mean_ms, digits= 3 ))
59+ end
3560 end
36- println (" -" ^ 60 )
61+ println (" -" ^ 72 )
3762end
3863
3964#= ============================================================================
@@ -65,6 +90,12 @@ function run_benchmark(name::String)
6590 # Prepare data with benchmark=true for larger sizes
6691 data = @invokelatest mod. prepare (; benchmark= true )
6792
93+ # Get metric info if available
94+ metric_total, metric_unit = 0 , " "
95+ if isdefined (mod, :metric )
96+ metric_total, metric_unit = @invokelatest mod. metric (data)
97+ end
98+
6899 # Run cuTile
69100 result = @invokelatest mod. run (data; nruns= NRUNS, warmup= WARMUP)
70101
@@ -86,17 +117,17 @@ function run_benchmark(name::String)
86117 merge! (results, others)
87118 end
88119
89- return results
120+ return results, metric_total, metric_unit
90121end
91122
92123#= ============================================================================
93124 Main
94125=============================================================================#
95126
96127function main ()
97- println (" =" ^ 60 )
128+ println (" =" ^ 72 )
98129 println (" cuTile.jl Benchmarks" )
99- println (" =" ^ 60 )
130+ println (" =" ^ 72 )
100131 println ()
101132 println (" Configuration:" )
102133 println (" Runs: $NRUNS (+ $WARMUP warmup)" )
@@ -105,18 +136,21 @@ function main()
105136 for name in discover_benchmarks ()
106137 println (" \n Benchmarking $name ..." )
107138
108- results = run_benchmark (name)
109- if results === nothing
139+ ret = run_benchmark (name)
140+ if ret === nothing
110141 println (" (skipped - no prepare/run functions)" )
111142 continue
112143 end
113144
145+ results, metric_total, metric_unit = ret
146+
114147 # Convert to BenchmarkResult for printing
115148 benchmark_results = BenchmarkResult[]
116149 for (impl_name, times) in results
117150 min_t = minimum (times)
118151 mean_t = sum (times) / length (times)
119- push! (benchmark_results, BenchmarkResult (impl_name, min_t, mean_t))
152+ tp = ! isempty (metric_unit) ? format_throughput (metric_total, metric_unit, min_t) : " "
153+ push! (benchmark_results, BenchmarkResult (impl_name, min_t, mean_t, tp))
120154 end
121155
122156 # Sort by min time
@@ -126,9 +160,9 @@ function main()
126160 end
127161
128162 println ()
129- println (" =" ^ 60 )
163+ println (" =" ^ 72 )
130164 println (" Benchmark Complete" )
131- println (" =" ^ 60 )
165+ println (" =" ^ 72 )
132166end
133167
134168if abspath (PROGRAM_FILE ) == @__FILE__
0 commit comments