perf(cpu): route conv im2col GEMM through the vendored SIMD kernel#99
Conversation
matMulColBufFloat32 computes out = kernel[cOut,colWidth] @ colBuf^T. When the vendored GEMM kernel is wired in (gemmF32 != nil; amd64 AVX2+FMA, from born-ml#96), transpose colBuf to [colWidth,colHeight] and reuse gemmF32 instead of the scalar dot-product loop. Guarded to profitable shapes (colHeight >= one full column tile, cOut*colWidth*colHeight >= blockThreshold), so tiny depthwise-style calls (cOut=1, small colHeight) stay on the scalar path. The transposed buffer is pooled (sync.Pool) and grown in place: transposeF32 fully overwrites it each call, so reuse without re-zeroing is safe and avoids a large alloc + zero per conv (which showed up as runtime.memclr in the profile). Default builds and non-amd64 are unchanged (gemmF32 nil -> scalar). A GOEXPERIMENT=simd build also keeps gemmF32 nil (the avo GEMM is gated !goexperiment.simd), so conv stays scalar there, consistent with the foundation. Tests cover SIMD-vs-scalar parity and dispatch routing (regular conv -> gemm, depthwise/narrow -> scalar); the routing test is tagged amd64 && !goexperiment.simd to match the gemmF32/gemmAVX2F32 symbols it exercises. Part of born-ml#79. Builds on the GEMM foundation (born-ml#96); final piece split out of born-ml#80.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
kolkov
left a comment
There was a problem hiding this comment.
Clean integration — the transpose + GEMM routing is mathematically correct and the dispatch thresholds are conservative.
Correctness verified
out[cOut, colHeight] = kernel[cOut, colWidth] @ colBufT[colWidth, colHeight] — matches the original scalar loop. transposeF32 correctly maps [colHeight, colWidth] row-major to [colWidth, colHeight] row-major. gemmF32(out, kernel, colBufT, cOut, colWidth, colHeight) maps to C[m,n] = A[m,k] @ B[k,n]. ✓
Dispatch
colHeight >= gemmMinCols && cOut*colWidth*colHeight >= blockThreshold — correct. colHeight is n in gemmF32 terms. Small depthwise-style calls (cOut=1) stay scalar.
One request: add a comment clarifying the mapping, e.g.:
// colHeight is n in gemmF32(c, a, b, m, k, n); gemmMinCols ensures a full 16-wide column tile.Pool
sync.Pool for transpose scratch — safe. Full overwrite by transposeF32 makes dirty reuse correct. Matches gemmScratch pool pattern from #96.
Tests
- Sentinel
calledflag proves SIMD dispatch — correct pattern from GEMM tests - Parity tolerance
1e-4relative — appropriate for float32 dot products of length ~288 - BirdNET shapes covered + scalar-stay conditions
- No
t.Parallel()— safe forgemmF32mutation
Build tags
//go:build amd64 && !goexperiment.simd on test file — consistent with GEMM. conv_helpers.go has no tag — gemmF32 != nil runtime guard handles all platforms.
Approved. Add the colHeight = n mapping comment, then merge.
Summary
Final focused PR from #79, split out of the WIP in #80: route the convolution im2col matmul through the vendored AVX2+FMA GEMM kernel landed in #96.
matMulColBufFloat32computesout = kernel[cOut,colWidth] @ colBuf^T. When the GEMM kernel is wired in (gemmF32 != nil, i.e. an amd64 AVX2+FMA default build), it transposescolBufso the reduction axis becomes the row axis and reusesgemmF32instead of the scalar triple-loop. Guarded to profitable shapes (colHeight >= gemmMinCols,cOut*colWidth*colHeight >= blockThreshold), so tiny depthwise-style calls (cOut=1, smallcolHeight) stay on the scalar path.On the BirdNET v2.4 regular-conv shapes this is roughly 16-22x faster per call (and ~2x on small-
colHeightshapes).Pooled transpose buffer
The transposed
colBuf^T(up to a few MB for the larger 3x3 convs) was allocated fresh per call, and Go zeroes it at allocation even thoughtransposeF32immediately overwrites every element (it showed up asruntime.memclrin the profile). It is now pooled (sync.Pool) and grown in place; full overwrite each call makes un-zeroed reuse safe, and one buffer serves every conv.Build behavior
gemmF32nil, scalar path unchanged.GOEXPERIMENT=simd:gemmF32is nil there too (the avo GEMM is gated!goexperiment.simd, per feat(cpu): default-buildable AVX2+FMA GEMM kernel with always-on dispatch #96), so conv stays scalar, consistent with the foundation. The routing test is taggedamd64 && !goexperiment.simdto match thegemmF32/gemmAVX2F32symbols it exercises.Test plan
go test -short -race ./...(all 19 packages; default build exercises the conv SIMD routing)matMulColBufFloat32across conv shapesGOEXPERIMENT=simd go build ./...and tests pass (conv stays scalar, routing test excluded by tag)go vet,golangci-lint,gofmtcleanBuilds on #96. With this, #80 is fully superseded and can be closed.
Part of #79.