diff --git a/internal/backend/cpu/conv_gemm_amd64_test.go b/internal/backend/cpu/conv_gemm_amd64_test.go new file mode 100644 index 0000000..5c44e91 --- /dev/null +++ b/internal/backend/cpu/conv_gemm_amd64_test.go @@ -0,0 +1,125 @@ +//go:build amd64 && !goexperiment.simd + +package cpu + +import ( + "math" + "math/rand" + "testing" + + "golang.org/x/sys/cpu" +) + +// naiveColBufMatMul is an independent reference for matMulColBufFloat32: +// out[i*colHeight+j] = sum_k kernel[i*colWidth+k] * colBuf[j*colWidth+k]. +func naiveColBufMatMul(kernel, colBuf []float32, cOut, colHeight, colWidth int) []float32 { + out := make([]float32, cOut*colHeight) + for i := 0; i < cOut; i++ { + for j := 0; j < colHeight; j++ { + var s float32 + for k := 0; k < colWidth; k++ { + s += kernel[i*colWidth+k] * colBuf[j*colWidth+k] + } + out[i*colHeight+j] = s + } + } + return out +} + +// TestMatMulColBufGemmDispatch verifies the conv im2col GEMM routes profitable +// shapes through the SIMD gemm kernel (via a colBuf transpose) and keeps tiny +// depthwise-style shapes on scalar, with correct results either way. A sentinel +// proves the SIMD path is actually taken so the test is not vacuous. +func TestMatMulColBufGemmDispatch(t *testing.T) { + if !cpu.X86.HasAVX2 || !cpu.X86.HasFMA { + t.Skip("AVX2+FMA not available on this CPU") + } + r := rand.New(rand.NewSource(0x636f6e76)) // "conv" + + var called bool + withGemmF32(t, func(c, a, b []float32, m, k, n int) { + called = true + gemmAVX2F32(c, a, b, m, k, n) + }) + + cases := []struct { + name string + cOut, colHeight, colWidth int + wantGemm bool + }{ + {"regular conv routes to gemm", 64, 384, 288, true}, // 7.1M >= blockThreshold, colH >= 16 + {"depthwise cOut=1 stays scalar", 1, 24, 9, false}, // 216 < blockThreshold + {"narrow colHeight<16 stays scalar", 512, 8, 64, false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + called = false + kernel := randSliceF32(r, tc.cOut*tc.colWidth) + colBuf := randSliceF32(r, tc.colHeight*tc.colWidth) + out := make([]float32, tc.cOut*tc.colHeight) + for i := range out { + out[i] = 999.0 // poison + } + matMulColBufFloat32(out, kernel, colBuf, tc.cOut, tc.colHeight, tc.colWidth) + want := naiveColBufMatMul(kernel, colBuf, tc.cOut, tc.colHeight, tc.colWidth) + + if called != tc.wantGemm { + t.Errorf("gemm path taken = %v, want %v", called, tc.wantGemm) + } + var maxDiff float64 + for i := range want { + d := math.Abs(float64(out[i]-want[i])) / (1 + math.Abs(float64(want[i]))) + if d > maxDiff { + maxDiff = d + } + } + if maxDiff > 1e-4 { + t.Errorf("max rel diff %.3e exceeds 1e-4", maxDiff) + } + }) + } +} + +// BenchmarkMatMulColBuf compares the scalar conv im2col GEMM against the SIMD +// path (colBuf transpose + reused GEMM kernel) at representative regular-conv +// shapes from the BirdNET v2.4 model. +func BenchmarkMatMulColBuf(b *testing.B) { + shapes := []struct { + name string + cOut, colH, colW int + }{ + {"conv_864x96x108", 864, 96, 108}, + {"conv_288x384x72", 288, 384, 72}, + {"conv_1536x24x192", 1536, 24, 192}, + } + r := rand.New(rand.NewSource(11)) + for _, s := range shapes { + kernel := randSliceF32(r, s.cOut*s.colW) + colBuf := randSliceF32(r, s.colH*s.colW) + out := make([]float32, s.cOut*s.colH) + b.Run(s.name+"/scalar", func(b *testing.B) { + prev := gemmF32 + gemmF32 = nil + b.ResetTimer() + for i := 0; i < b.N; i++ { + matMulColBufFloat32(out, kernel, colBuf, s.cOut, s.colH, s.colW) + } + b.StopTimer() + gemmF32 = prev + }) + b.Run(s.name+"/simd", func(b *testing.B) { + if !cpu.X86.HasAVX2 || !cpu.X86.HasFMA { + b.Skip("AVX2+FMA not available") + } + prev := gemmF32 + gemmF32 = gemmAVX2F32 + b.ResetTimer() + for i := 0; i < b.N; i++ { + matMulColBufFloat32(out, kernel, colBuf, s.cOut, s.colH, s.colW) + } + b.StopTimer() + gemmF32 = prev + }) + } +} diff --git a/internal/backend/cpu/conv_helpers.go b/internal/backend/cpu/conv_helpers.go index 79da0db..62ada0f 100644 --- a/internal/backend/cpu/conv_helpers.go +++ b/internal/backend/cpu/conv_helpers.go @@ -1,5 +1,12 @@ package cpu +import "sync" + +// colBufTPool reuses the transposed im2col buffer (colBuf^T) fed to the SIMD GEMM +// across convolutions. The buffer is fully overwritten by transposeF32 every call, +// so a pooled (un-zeroed) buffer is safe and avoids a large alloc + zero per conv. +var colBufTPool = sync.Pool{New: func() any { s := []float32(nil); return &s }} + // conv_helpers.go — inner-loop helper functions for Conv2D and MaxPool2D. // // These helpers are extracted from the innermost loops to reduce cognitive @@ -11,6 +18,24 @@ package cpu // Computes output[i*colHeight+j] = sum_k kernel[i*colWidth+k] * col[j*colWidth+k] // for all i in [0, cOut) and j in [0, colHeight). func matMulColBufFloat32(outputData, kernelData, colBuf []float32, cOut, colHeight, colWidth int) { + // SIMD fast path: this is out = kernel[cOut,colWidth] @ colBuf^T[colWidth,colHeight]. + // Transpose colBuf so the reduction axis becomes the row axis, then reuse the + // vendored GEMM kernel. Guarded to profitable shapes (the kernel needs a full + // column tile); tiny depthwise-style calls (cOut=1, small colHeight) stay scalar. + if gemmF32 != nil && colHeight >= gemmMinCols && cOut*colWidth*colHeight >= blockThreshold { + need := colWidth * colHeight + p := colBufTPool.Get().(*[]float32) + if cap(*p) < need { + *p = make([]float32, need) + } else { + *p = (*p)[:need] + } + colBufT := *p + transposeF32(colBufT, colBuf, colHeight, colWidth) // fully overwrites colBufT + gemmF32(outputData, kernelData, colBufT, cOut, colWidth, colHeight) + colBufTPool.Put(p) + return + } for i := 0; i < cOut; i++ { kernelRow := kernelData[i*colWidth : i*colWidth+colWidth] for j := 0; j < colHeight; j++ { diff --git a/internal/backend/cpu/matmul_gemm.go b/internal/backend/cpu/matmul_gemm.go index a1d43e2..c0922ea 100644 --- a/internal/backend/cpu/matmul_gemm.go +++ b/internal/backend/cpu/matmul_gemm.go @@ -18,3 +18,15 @@ const gemmMinCols = 16 // build where the archsimd micro-kernel owns dispatch) the scalar path is used // unchanged. var gemmF32 func(c, a, b []float32, m, k, n int) + +// transposeF32 writes the [rows, cols] -> [cols, rows] transpose of src into dst: +// dst[c*rows+r] = src[r*cols+c]. Used to recast the conv im2col product +// out = kernel @ colBuf^T into the GEMM kernel's A @ B form. +func transposeF32(dst, src []float32, rows, cols int) { + for r := 0; r < rows; r++ { + base := r * cols + for c := 0; c < cols; c++ { + dst[c*rows+r] = src[base+c] + } + } +}