diff --git a/internal/backend/cpu/conv2d.go b/internal/backend/cpu/conv2d.go index 639572a..c5dfbe0 100644 --- a/internal/backend/cpu/conv2d.go +++ b/internal/backend/cpu/conv2d.go @@ -140,20 +140,26 @@ func conv2dFloat32Stride1NoPad(output, input, kernel *tensor.RawTensor, dims *Co HOut := dims.HOut WOut := dims.WOut - // Step 1: Im2col with stride=1, padding=0 + // Step 1: Im2col with stride=1, padding=0. colBuf is recycled from a pool and + // fully overwritten by im2col, so it needs no zeroing. colWidth := CIn * KH * KW colHeight := N * HOut * WOut - colBuf := make([]float32, colHeight*colWidth) + colp := poolScratch[float32](&convColPoolF32, colHeight*colWidth) + colBuf := *colp + defer convColPoolF32.Put(colp) im2colFloat32Stride1NoPad(colBuf, inputData, dims) - // Step 2: Matrix multiplication via helper (inlined by compiler). - matMulColBufFloat32(outputData, kernelData, colBuf, COut, colHeight, colWidth) + // Step 2: Matrix multiply into pooled scratch (len == len(outputData)). + // matMulColBufFloat32 writes every element, so the un-zeroed buffer is safe. + matp := poolScratch[float32](&convOutPoolF32, COut*colHeight) + matOut := *matp + defer convOutPoolF32.Put(matp) + matMulColBufFloat32(matOut, kernelData, colBuf, COut, colHeight, colWidth) - // Step 3: Rearrange from [C_out, N*H_out*W_out] to [N, C_out, H_out, W_out]. - tempBuf := make([]float32, len(outputData)) - copy(tempBuf, outputData) - rearrangeOutputFloat32(outputData, tempBuf, N, COut, HOut, WOut, colHeight) + // Step 3: Rearrange [C_out, N*H_out*W_out] -> [N, C_out, H_out, W_out], + // permuting the matmul scratch straight into the output (no intermediate copy). 
+ rearrangeOutputFloat32(outputData, matOut, N, COut, HOut, WOut, colHeight) } // pointwiseConvFloat32 computes a 1x1 convolution (stride=1, padding=0) as a @@ -196,24 +202,29 @@ func conv2dFloat32General(output, input, kernel *tensor.RawTensor, dims *ConvDim HOut := dims.HOut WOut := dims.WOut - // Step 1: Im2col transformation - // colBuf: [N * H_out * W_out, C_in * K_h * K_w] + // Step 1: Im2col transformation. colBuf: [N*H_out*W_out, C_in*K_h*K_w]. + // Pooled scratch, fully overwritten by im2col, so no zeroing needed. colWidth := CIn * KH * KW colHeight := N * HOut * WOut - colBuf := make([]float32, colHeight*colWidth) + colp := poolScratch[float32](&convColPoolF32, colHeight*colWidth) + colBuf := *colp + defer convColPoolF32.Put(colp) im2colFloat32(colBuf, inputData, dims) // Step 2: Reshape kernel — already in [C_out, C_in*K_h*K_w] layout (row-major). - // Step 3: Matrix multiplication via helper (inlined by compiler). - // kernel: [C_out, C_in*K_h*K_w] @ colBuf^T -> [C_out, N*H_out*W_out] - matMulColBufFloat32(outputData, kernelData, colBuf, COut, colHeight, colWidth) - - // Step 4: Rearrange from [C_out, N*H_out*W_out] to [N, C_out, H_out, W_out]. - tempBuf := make([]float32, len(outputData)) - copy(tempBuf, outputData) - rearrangeOutputFloat32(outputData, tempBuf, N, COut, HOut, WOut, colHeight) + // Step 3: Matrix multiply into pooled scratch (len == len(outputData)). + // kernel: [C_out, C_in*K_h*K_w] @ colBuf^T -> [C_out, N*H_out*W_out]. + // matMulColBufFloat32 writes every element, so the un-zeroed buffer is safe. + matp := poolScratch[float32](&convOutPoolF32, COut*colHeight) + matOut := *matp + defer convOutPoolF32.Put(matp) + matMulColBufFloat32(matOut, kernelData, colBuf, COut, colHeight, colWidth) + + // Step 4: Rearrange [C_out, N*H_out*W_out] -> [N, C_out, H_out, W_out], + // permuting the matmul scratch straight into the output (no intermediate copy). 
+ rearrangeOutputFloat32(outputData, matOut, N, COut, HOut, WOut, colHeight) } // im2colFloat32Stride1NoPad is optimized for stride=1, padding=0. @@ -370,19 +381,25 @@ func conv2dFloat64Stride1NoPad(output, input, kernel *tensor.RawTensor, dims *Co HOut := dims.HOut WOut := dims.WOut - // Im2col with stride=1, padding=0 + // Im2col with stride=1, padding=0. Pooled scratch, fully overwritten by + // im2col, so no zeroing needed. colWidth := CIn * KH * KW colHeight := N * HOut * WOut - colBuf := make([]float64, colHeight*colWidth) + colp := poolScratch[float64](&convColPoolF64, colHeight*colWidth) + colBuf := *colp + defer convColPoolF64.Put(colp) im2colFloat64Stride1NoPad(colBuf, inputData, dims) - // MatMul via helper (inlined by compiler). - matMulColBufFloat64(outputData, kernelData, colBuf, COut, colHeight, colWidth) + // MatMul into pooled scratch (len == len(outputData)); matMulColBufFloat64 + // writes every element, so the un-zeroed buffer is safe. + matp := poolScratch[float64](&convOutPoolF64, COut*colHeight) + matOut := *matp + defer convOutPoolF64.Put(matp) + matMulColBufFloat64(matOut, kernelData, colBuf, COut, colHeight, colWidth) - // Rearrange from [C_out, N*H_out*W_out] to [N, C_out, H_out, W_out]. - tempBuf := make([]float64, len(outputData)) - copy(tempBuf, outputData) - rearrangeOutputFloat64(outputData, tempBuf, N, COut, HOut, WOut, colHeight) + // Rearrange [C_out, N*H_out*W_out] -> [N, C_out, H_out, W_out], permuting the + // matmul scratch straight into the output (no intermediate copy). + rearrangeOutputFloat64(outputData, matOut, N, COut, HOut, WOut, colHeight) } // pointwiseConvFloat64 is the float64 counterpart of pointwiseConvFloat32 and @@ -418,19 +435,24 @@ func conv2dFloat64General(output, input, kernel *tensor.RawTensor, dims *ConvDim HOut := dims.HOut WOut := dims.WOut - // Im2col + // Im2col. Pooled scratch, fully overwritten by im2col, so no zeroing needed. 
colWidth := CIn * KH * KW colHeight := N * HOut * WOut - colBuf := make([]float64, colHeight*colWidth) + colp := poolScratch[float64](&convColPoolF64, colHeight*colWidth) + colBuf := *colp + defer convColPoolF64.Put(colp) im2colFloat64(colBuf, inputData, dims) - // MatMul via helper (inlined by compiler). - matMulColBufFloat64(outputData, kernelData, colBuf, COut, colHeight, colWidth) + // MatMul into pooled scratch (len == len(outputData)); matMulColBufFloat64 + // writes every element, so the un-zeroed buffer is safe. + matp := poolScratch[float64](&convOutPoolF64, COut*colHeight) + matOut := *matp + defer convOutPoolF64.Put(matp) + matMulColBufFloat64(matOut, kernelData, colBuf, COut, colHeight, colWidth) - // Rearrange from [C_out, N*H_out*W_out] to [N, C_out, H_out, W_out]. - tempBuf := make([]float64, len(outputData)) - copy(tempBuf, outputData) - rearrangeOutputFloat64(outputData, tempBuf, N, COut, HOut, WOut, colHeight) + // Rearrange [C_out, N*H_out*W_out] -> [N, C_out, H_out, W_out], permuting the + // matmul scratch straight into the output (no intermediate copy). + rearrangeOutputFloat64(outputData, matOut, N, COut, HOut, WOut, colHeight) } // im2colFloat64Stride1NoPad is optimized for stride=1, padding=0. diff --git a/internal/backend/cpu/conv2d_pooling_test.go b/internal/backend/cpu/conv2d_pooling_test.go new file mode 100644 index 0000000..8330d23 --- /dev/null +++ b/internal/backend/cpu/conv2d_pooling_test.go @@ -0,0 +1,241 @@ +package cpu + +import ( + "math" + "testing" + + "github.com/born-ml/born/internal/tensor" +) + +// convScratchCase describes a regular (non-1x1) convolution that exercises the +// im2col path in conv2dFloat32 / conv2dFloat64 (colBuf + matmul + rearrange). +type convScratchCase struct { + name string + n, cIn, h, w, cOut, kh, kw int + stride, padding int +} + +// Cases cover both specialized im2col paths (stride=1/pad=0 and the general +// path), at a size that routes through the SIMD GEMM and one that stays scalar. 
+var convScratchCases = []convScratchCase{ + {"stride1nopad_gemm", 1, 8, 16, 16, 32, 3, 3, 1, 0}, // colHeight*cOut*colWidth >= blockThreshold -> SIMD GEMM + {"stride1nopad_scalar", 1, 4, 10, 10, 6, 3, 3, 1, 0}, // small -> scalar matmul + {"padded_general", 1, 8, 12, 12, 16, 3, 3, 1, 1}, // general path with padding + {"strided_general", 1, 8, 16, 16, 16, 3, 3, 2, 0}, // general path with stride 2 +} + +// buildConvScratch builds a pre-allocated output plus input/kernel tensors and +// the ConvDims for a case, so a test can call conv2dFloat32/conv2dFloat64 +// directly (white-box) without the per-call output-tensor allocation that +// backend.Conv2D adds. +func buildConvScratch(c convScratchCase, dt tensor.DataType) (output, input, kernel *tensor.RawTensor, dims *ConvDims) { + hOut := (c.h+2*c.padding-c.kh)/c.stride + 1 + wOut := (c.w+2*c.padding-c.kw)/c.stride + 1 + input, _ = tensor.NewRaw(tensor.Shape{c.n, c.cIn, c.h, c.w}, dt, tensor.CPU) + kernel, _ = tensor.NewRaw(tensor.Shape{c.cOut, c.cIn, c.kh, c.kw}, dt, tensor.CPU) + output, _ = tensor.NewRaw(tensor.Shape{c.n, c.cOut, hOut, wOut}, dt, tensor.CPU) + fillPointwiseConv(input, func(i int) float64 { return float64((i%13)-6) * 0.25 }) + fillPointwiseConv(kernel, func(i int) float64 { return float64((i%7)-3) * 0.5 }) + dims = &ConvDims{ + N: c.n, CIn: c.cIn, H: c.h, W: c.w, + COut: c.cOut, KH: c.kh, KW: c.kw, + HOut: hOut, WOut: wOut, + Stride: c.stride, Padding: c.padding, + } + return output, input, kernel, dims +} + +func runConvScratch(dt tensor.DataType, output, input, kernel *tensor.RawTensor, dims *ConvDims) { + switch dt { + case tensor.Float32: + conv2dFloat32(output, input, kernel, dims) + case tensor.Float64: + conv2dFloat64(output, input, kernel, dims) + } +} + +// TestConv2DScratchAllocFree verifies the im2col conv path recycles its colBuf +// and matmul-output scratch from a pool: after the pool warms, a convolution +// into a pre-allocated output allocates nothing per call. 
Skipped under +// -short or the race detector, because testing.AllocsPerRun over a shared +// sync.Pool is unreliable in both modes (CI runs -short -race). +func TestConv2DScratchAllocFree(t *testing.T) { + if testing.Short() || raceEnabled { + t.Skip("AllocsPerRun over a shared sync.Pool is unreliable under -short and the race detector") + } + for _, dt := range []tensor.DataType{tensor.Float32, tensor.Float64} { + for _, c := range convScratchCases { + t.Run(dt.String()+"/"+c.name, func(t *testing.T) { + output, input, kernel, dims := buildConvScratch(c, dt) + allocs := testing.AllocsPerRun(20, func() { + runConvScratch(dt, output, input, kernel, dims) + }) + if allocs != 0 { + t.Errorf("conv allocates %.0f scratch buffers/op, want 0", allocs) + } + }) + } + } +} + +// convDataEqual reports the first index where two same-dtype conv outputs differ +// bit-for-bit, or (-1, true) if identical. +func convDataEqual(a, b *tensor.RawTensor) (idx int, equal bool) { + switch a.DType() { + case tensor.Float32: + ad, bd := a.AsFloat32(), b.AsFloat32() + for i := range ad { + if ad[i] != bd[i] { + return i, false + } + } + case tensor.Float64: + ad, bd := a.AsFloat64(), b.AsFloat64() + for i := range ad { + if ad[i] != bd[i] { + return i, false + } + } + } + return -1, true +} + +// TestConv2DPooledReuseDeterministic runs each convolution repeatedly, sharing +// the recycled scratch pools across calls, and asserts every run is bit-for-bit +// identical to the first. A dirty buffer leaking through reuse would diverge. 
+func TestConv2DPooledReuseDeterministic(t *testing.T) { + backend := New() + for _, dt := range []tensor.DataType{tensor.Float32, tensor.Float64} { + for _, c := range convScratchCases { + t.Run(dt.String()+"/"+c.name, func(t *testing.T) { + _, input, kernel, _ := buildConvScratch(c, dt) + first := backend.Conv2D(input, kernel, c.stride, c.padding) + for i := 0; i < 8; i++ { + got := backend.Conv2D(input, kernel, c.stride, c.padding) + if idx, ok := convDataEqual(first, got); !ok { + t.Fatalf("run %d diverged from first at index %d", i+1, idx) + } + } + }) + } + } +} + +// poisonConvPools pre-dirties the recycled scratch pools for dtype dt with a +// sentinel at exactly the sizes the next conv of this shape will request, so the +// conv reuses the poisoned buffers. If im2col and the matmul fully overwrite +// their buffers (the pooling safety contract), the sentinel never reaches the +// output. +func poisonConvPools(dt tensor.DataType, colN, outN int) { + const sentinel = -123456.0 + switch dt { + case tensor.Float32: + cp := poolScratch[float32](&convColPoolF32, colN) + for i := range *cp { + (*cp)[i] = sentinel + } + convColPoolF32.Put(cp) + mp := poolScratch[float32](&convOutPoolF32, outN) + for i := range *mp { + (*mp)[i] = sentinel + } + convOutPoolF32.Put(mp) + case tensor.Float64: + cp := poolScratch[float64](&convColPoolF64, colN) + for i := range *cp { + (*cp)[i] = sentinel + } + convColPoolF64.Put(cp) + mp := poolScratch[float64](&convOutPoolF64, outN) + for i := range *mp { + (*mp)[i] = sentinel + } + convOutPoolF64.Put(mp) + } +} + +// TestConv2DPooledPoisonedOverwrite poisons the recycled scratch pools with a +// sentinel, then asserts the conv output is bit-identical to a clean run. This +// proves im2col and the matmul fully overwrite the recycled buffers, so dirty +// reuse is safe (the pattern proven for the GEMM kernel in #96 and #99). 
+func TestConv2DPooledPoisonedOverwrite(t *testing.T) { + backend := New() + for _, dt := range []tensor.DataType{tensor.Float32, tensor.Float64} { + for _, c := range convScratchCases { + t.Run(dt.String()+"/"+c.name, func(t *testing.T) { + _, input, kernel, dims := buildConvScratch(c, dt) + clean := backend.Conv2D(input, kernel, c.stride, c.padding) + + colN := dims.CIn * dims.KH * dims.KW * (dims.N * dims.HOut * dims.WOut) + outN := dims.COut * (dims.N * dims.HOut * dims.WOut) + poisonConvPools(dt, colN, outN) + + got := backend.Conv2D(input, kernel, c.stride, c.padding) + if idx, ok := convDataEqual(clean, got); !ok { + t.Fatalf("poisoned scratch leaked into output at index %d", idx) + } + }) + } + } +} + +// TestConv2DPooledMatchesMock checks the pooled im2col path against the naive +// MockBackend oracle across the regular-conv shapes (both specialized paths, +// scalar and SIMD-GEMM routing), for both dtypes. +func TestConv2DPooledMatchesMock(t *testing.T) { + backend := New() + mock := tensor.NewMockBackend() + for _, dt := range []tensor.DataType{tensor.Float32, tensor.Float64} { + // Relative tolerance: float32 GEMM reorders FMA accumulation; float64 is + // effectively exact. + tol := 1e-4 + if dt == tensor.Float64 { + tol = 1e-12 + } + for _, c := range convScratchCases { + t.Run(dt.String()+"/"+c.name, func(t *testing.T) { + _, input, kernel, _ := buildConvScratch(c, dt) + got := backend.Conv2D(input, kernel, c.stride, c.padding) + want := mock.Conv2D(input, kernel, c.stride, c.padding) + if !got.Shape().Equal(want.Shape()) { + t.Fatalf("shape: CPU=%v Mock=%v", got.Shape(), want.Shape()) + } + if d, idx := maxPointwiseConvDiff(got, want); d > tol*(1+absAt(want, idx)) { + t.Errorf("idx %d: abs diff %.3g exceeds rel tol %.3g", idx, d, tol) + } + }) + } + } +} + +// absAt returns |value| at flat index i of a same-dtype tensor (helper for a +// relative-tolerance check). 
+func absAt(t *tensor.RawTensor, i int) float64 { + if i < 0 { + return 0 + } + switch t.DType() { + case tensor.Float32: + return math.Abs(float64(t.AsFloat32()[i])) + case tensor.Float64: + return math.Abs(t.AsFloat64()[i]) + } + return 0 +} + +func benchConv2DIm2col(b *testing.B, n, cIn, h, w, cOut, k, stride, pad int) { + backend := New() + input := tensor.Randn[float32](tensor.Shape{n, cIn, h, w}, backend).Raw() + kernel := tensor.Randn[float32](tensor.Shape{cOut, cIn, k, k}, backend).Raw() + b.ResetTimer() + for b.Loop() { + backend.Conv2D(input, kernel, stride, pad) + } +} + +// Regular (KxK, K>1) conv layers that exercise the im2col + matmul + rearrange +// path whose colBuf and matmul-output scratch are pooled. Run with -benchmem to +// see the per-conv allocation and B/op drop from recycling those buffers. +func BenchmarkConv2D_Im2col_GEMM(b *testing.B) { benchConv2DIm2col(b, 1, 32, 64, 64, 64, 3, 1, 1) } +func BenchmarkConv2D_Im2col_Deep(b *testing.B) { benchConv2DIm2col(b, 1, 64, 32, 32, 128, 3, 1, 1) } +func BenchmarkConv2D_Im2col_Strided(b *testing.B) { benchConv2DIm2col(b, 1, 16, 64, 64, 32, 3, 2, 1) } diff --git a/internal/backend/cpu/conv_helpers.go b/internal/backend/cpu/conv_helpers.go index 62ada0f..a177340 100644 --- a/internal/backend/cpu/conv_helpers.go +++ b/internal/backend/cpu/conv_helpers.go @@ -7,6 +7,37 @@ import "sync" // so a pooled (un-zeroed) buffer is safe and avoids a large alloc + zero per conv. var colBufTPool = sync.Pool{New: func() any { s := []float32(nil); return &s }} +// The im2col conv path (conv2dFloat32 / conv2dFloat64) recycles two more large +// ephemeral buffers per call: convColPool* holds the im2col column buffer, and +// convOutPool* holds the matmul output that feeds the NCHW rearrange. 
Both are +// fully overwritten before any read (im2col writes every column entry, including +// padding zeros; matMulColBuf* writes every output[i*colHeight+j]), so a pooled +// un-zeroed buffer is safe and skips a make + memclr per conv. Recycling the +// matmul output also removes the old copy(outputData -> tempBuf) memmove, because +// the matmul writes the scratch directly and rearrange then permutes it into the +// output. Mirrors colBufTPool above and the gemmScratch pool from the GEMM kernel. +var ( + convColPoolF32 = sync.Pool{New: func() any { s := []float32(nil); return &s }} + convOutPoolF32 = sync.Pool{New: func() any { s := []float32(nil); return &s }} + convColPoolF64 = sync.Pool{New: func() any { s := []float64(nil); return &s }} + convOutPoolF64 = sync.Pool{New: func() any { s := []float64(nil); return &s }} +) + +// poolScratch returns a length-n slice backed by a pooled array from p, growing +// the array only when the cached capacity is too small. The slice length is +// exactly n (no backing-slice slack leaks into indexing), and its contents are +// NOT zeroed: callers must fully overwrite the slice before reading it, then +// return the pointer to p with Put. +func poolScratch[T any](p *sync.Pool, n int) *[]T { + sp := p.Get().(*[]T) + if cap(*sp) < n { + *sp = make([]T, n) + } else { + *sp = (*sp)[:n] + } + return sp +} + // conv_helpers.go — inner-loop helper functions for Conv2D and MaxPool2D. // // These helpers are extracted from the innermost loops to reduce cognitive @@ -23,17 +54,11 @@ func matMulColBufFloat32(outputData, kernelData, colBuf []float32, cOut, colHeig // vendored GEMM kernel. Guarded to profitable shapes (the kernel needs a full // column tile); tiny depthwise-style calls (cOut=1, small colHeight) stay scalar. 
if gemmF32 != nil && colHeight >= gemmMinCols && cOut*colWidth*colHeight >= blockThreshold { - need := colWidth * colHeight - p := colBufTPool.Get().(*[]float32) - if cap(*p) < need { - *p = make([]float32, need) - } else { - *p = (*p)[:need] - } + p := poolScratch[float32](&colBufTPool, colWidth*colHeight) + defer colBufTPool.Put(p) colBufT := *p transposeF32(colBufT, colBuf, colHeight, colWidth) // fully overwrites colBufT gemmF32(outputData, kernelData, colBufT, cOut, colWidth, colHeight) - colBufTPool.Put(p) return } for i := 0; i < cOut; i++ { diff --git a/internal/backend/cpu/race_off_test.go b/internal/backend/cpu/race_off_test.go new file mode 100644 index 0000000..5c9768b --- /dev/null +++ b/internal/backend/cpu/race_off_test.go @@ -0,0 +1,7 @@ +//go:build !race + +package cpu + +// raceEnabled reports whether the test binary was built with -race. See +// race_on_test.go for why alloc-counting tests consult it. +const raceEnabled = false diff --git a/internal/backend/cpu/race_on_test.go b/internal/backend/cpu/race_on_test.go new file mode 100644 index 0000000..04dba9a --- /dev/null +++ b/internal/backend/cpu/race_on_test.go @@ -0,0 +1,8 @@ +//go:build race + +package cpu + +// raceEnabled reports whether the test binary was built with -race. The race +// detector adds shadow allocations that make testing.AllocsPerRun unreliable, so +// alloc-counting tests skip when it is on (independently of -short). +const raceEnabled = true