Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions internal/backend/cpu/conv_gemm_amd64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//go:build amd64 && !goexperiment.simd

package cpu

import (
"math"
"math/rand"
"testing"

"golang.org/x/sys/cpu"
)

// naiveColBufMatMul is a straightforward triple-loop reference used to check
// matMulColBufFloat32. kernel is read as cOut rows of colWidth values and colBuf
// as colHeight rows of colWidth values; the result holds cOut rows of colHeight
// values where
//
//	result[i*colHeight+j] = sum over k of kernel[i*colWidth+k] * colBuf[j*colWidth+k].
func naiveColBufMatMul(kernel, colBuf []float32, cOut, colHeight, colWidth int) []float32 {
	result := make([]float32, cOut*colHeight)
	for row := 0; row < cOut; row++ {
		kBase := row * colWidth    // start of kernel row `row`
		outBase := row * colHeight // start of output row `row`
		for col := 0; col < colHeight; col++ {
			cBase := col * colWidth // start of colBuf row `col`
			acc := float32(0)
			// Accumulate in ascending k so rounding matches a plain dot product.
			for k := 0; k < colWidth; k++ {
				acc += kernel[kBase+k] * colBuf[cBase+k]
			}
			result[outBase+col] = acc
		}
	}
	return result
}

// TestMatMulColBufGemmDispatch verifies the conv im2col GEMM routes profitable
// shapes through the SIMD gemm kernel (via a colBuf transpose) and keeps tiny
// depthwise-style shapes on scalar, with correct results either way. A sentinel
// proves the SIMD path is actually taken so the test is not vacuous.
//
// Each case pre-fills the output with a poison value so elements the
// implementation never writes show up as a large difference. The comparison
// also rejects NaN explicitly: NaN compares false against everything, so a
// plain "d > maxDiff" scan would let a NaN output pass unnoticed.
func TestMatMulColBufGemmDispatch(t *testing.T) {
	if !cpu.X86.HasAVX2 || !cpu.X86.HasFMA {
		t.Skip("AVX2+FMA not available on this CPU")
	}
	r := rand.New(rand.NewSource(0x636f6e76)) // "conv"

	// Wrap the AVX2 kernel so each subtest can observe whether dispatch reached
	// it. NOTE(review): withGemmF32 is defined elsewhere; presumably it installs
	// this hook as gemmF32 for the duration of the test — confirm.
	var called bool
	withGemmF32(t, func(c, a, b []float32, m, k, n int) {
		called = true
		gemmAVX2F32(c, a, b, m, k, n)
	})

	cases := []struct {
		name                      string
		cOut, colHeight, colWidth int
		wantGemm                  bool
	}{
		{"regular conv routes to gemm", 64, 384, 288, true}, // 7.1M >= blockThreshold, colH >= 16
		{"depthwise cOut=1 stays scalar", 1, 24, 9, false},  // 216 < blockThreshold
		{"narrow colHeight<16 stays scalar", 512, 8, 64, false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			called = false // reset the sentinel; subtests run sequentially
			kernel := randSliceF32(r, tc.cOut*tc.colWidth)
			colBuf := randSliceF32(r, tc.colHeight*tc.colWidth)
			out := make([]float32, tc.cOut*tc.colHeight)
			for i := range out {
				out[i] = 999.0 // poison
			}
			matMulColBufFloat32(out, kernel, colBuf, tc.cOut, tc.colHeight, tc.colWidth)
			want := naiveColBufMatMul(kernel, colBuf, tc.cOut, tc.colHeight, tc.colWidth)

			if called != tc.wantGemm {
				t.Errorf("gemm path taken = %v, want %v", called, tc.wantGemm)
			}
			// Relative error with a +1 in the denominator so values near zero
			// are judged on absolute difference.
			var maxDiff float64
			for i := range want {
				d := math.Abs(float64(out[i]-want[i])) / (1 + math.Abs(float64(want[i])))
				if math.IsNaN(d) {
					// A NaN would never exceed maxDiff, so fail on it directly.
					t.Fatalf("out[%d] = %v, want %v: difference is NaN", i, out[i], want[i])
				}
				if d > maxDiff {
					maxDiff = d
				}
			}
			if maxDiff > 1e-4 {
				t.Errorf("max rel diff %.3e exceeds 1e-4", maxDiff)
			}
		})
	}
}

// BenchmarkMatMulColBuf times matMulColBufFloat32 twice per shape: once with
// gemmF32 cleared (the "/scalar" sub-benchmark) and once with gemmF32 set to
// gemmAVX2F32 (the "/simd" sub-benchmark, skipped without AVX2+FMA). The shapes
// are regular-conv shapes from the BirdNET v2.4 model.
func BenchmarkMatMulColBuf(b *testing.B) {
	type convShape struct {
		name             string
		cOut, colH, colW int
	}
	shapes := []convShape{
		{"conv_864x96x108", 864, 96, 108},
		{"conv_288x384x72", 288, 384, 72},
		{"conv_1536x24x192", 1536, 24, 192},
	}
	rng := rand.New(rand.NewSource(11))
	for _, shape := range shapes {
		kernel := randSliceF32(rng, shape.cOut*shape.colW)
		colBuf := randSliceF32(rng, shape.colH*shape.colW)
		out := make([]float32, shape.cOut*shape.colH)

		// timeWith swaps impl in as gemmF32, runs the timed loop, and puts the
		// previous value back once the timer has stopped.
		timeWith := func(b *testing.B, impl func(c, a, bm []float32, m, k, n int)) {
			saved := gemmF32
			gemmF32 = impl
			b.ResetTimer()
			for iter := 0; iter < b.N; iter++ {
				matMulColBufFloat32(out, kernel, colBuf, shape.cOut, shape.colH, shape.colW)
			}
			b.StopTimer()
			gemmF32 = saved
		}

		b.Run(shape.name+"/scalar", func(b *testing.B) {
			timeWith(b, nil)
		})
		b.Run(shape.name+"/simd", func(b *testing.B) {
			// Skip before touching gemmF32 so nothing needs restoring.
			if !(cpu.X86.HasAVX2 && cpu.X86.HasFMA) {
				b.Skip("AVX2+FMA not available")
			}
			timeWith(b, gemmAVX2F32)
		})
	}
}
25 changes: 25 additions & 0 deletions internal/backend/cpu/conv_helpers.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package cpu

import "sync"

// colBufTPool hands out *[]float32 scratch buffers for the transposed im2col
// matrix (colBuf^T) that feeds the SIMD GEMM. A fresh entry points at a nil
// slice; the caller grows or reslices it to the size it needs. Contents are not
// cleared between uses, which is safe because transposeF32 overwrites the whole
// buffer on every call, and it saves a large alloc + zero per conv.
var colBufTPool = sync.Pool{
	New: func() any { return new([]float32) },
}

// conv_helpers.go — inner-loop helper functions for Conv2D and MaxPool2D.
//
// These helpers are extracted from the innermost loops to reduce cognitive
Expand All @@ -11,6 +18,24 @@
// Computes output[i*colHeight+j] = sum_k kernel[i*colWidth+k] * col[j*colWidth+k]
// for all i in [0, cOut) and j in [0, colHeight).
func matMulColBufFloat32(outputData, kernelData, colBuf []float32, cOut, colHeight, colWidth int) {
// SIMD fast path: this is out = kernel[cOut,colWidth] @ colBuf^T[colWidth,colHeight].
// Transpose colBuf so the reduction axis becomes the row axis, then reuse the
// vendored GEMM kernel. Guarded to profitable shapes (the kernel needs a full
// column tile); tiny depthwise-style calls (cOut=1, small colHeight) stay scalar.
if gemmF32 != nil && colHeight >= gemmMinCols && cOut*colWidth*colHeight >= blockThreshold {
need := colWidth * colHeight
p := colBufTPool.Get().(*[]float32)
if cap(*p) < need {
*p = make([]float32, need)
} else {
*p = (*p)[:need]

Check warning on line 31 in internal/backend/cpu/conv_helpers.go

View check run for this annotation

Codecov / codecov/patch

internal/backend/cpu/conv_helpers.go#L31

Added line #L31 was not covered by tests
}
colBufT := *p
transposeF32(colBufT, colBuf, colHeight, colWidth) // fully overwrites colBufT
gemmF32(outputData, kernelData, colBufT, cOut, colWidth, colHeight)
colBufTPool.Put(p)
return
}
for i := 0; i < cOut; i++ {
kernelRow := kernelData[i*colWidth : i*colWidth+colWidth]
for j := 0; j < colHeight; j++ {
Expand Down
12 changes: 12 additions & 0 deletions internal/backend/cpu/matmul_gemm.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ const gemmMinCols = 16
// build where the archsimd micro-kernel owns dispatch) the scalar path is used
// unchanged.
var gemmF32 func(c, a, b []float32, m, k, n int)

// transposeF32 copies the row-major [rows, cols] matrix in src into dst as its
// row-major [cols, rows] transpose, i.e. dst[c*rows+r] = src[r*cols+c]. It is
// used to recast the conv im2col product out = kernel @ colBuf^T into the GEMM
// kernel's A @ B form.
func transposeF32(dst, src []float32, rows, cols int) {
	srcIdx := 0 // walks src sequentially; equals r*cols+c
	for r := 0; r < rows; r++ {
		dstIdx := r // column r of dst; advances by one dst row per source column
		for c := 0; c < cols; c++ {
			dst[dstIdx] = src[srcIdx]
			dstIdx += rows
			srcIdx++
		}
	}
}
Loading