diff --git a/internal/onnx/operators/_gen/pow/go.mod b/internal/onnx/operators/_gen/pow/go.mod new file mode 100644 index 0000000..d6e6655 --- /dev/null +++ b/internal/onnx/operators/_gen/pow/go.mod @@ -0,0 +1,11 @@ +module borngen/pow + +go 1.26.0 + +require github.com/mmcloughlin/avo v0.6.0 + +require ( + golang.org/x/mod v0.37.0 // indirect + golang.org/x/sync v0.21.0 // indirect + golang.org/x/tools v0.46.0 // indirect +) diff --git a/internal/onnx/operators/_gen/pow/go.sum b/internal/onnx/operators/_gen/pow/go.sum new file mode 100644 index 0000000..b549ac8 --- /dev/null +++ b/internal/onnx/operators/_gen/pow/go.sum @@ -0,0 +1,10 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/mmcloughlin/avo v0.6.0 h1:QH6FU8SKoTLaVs80GA8TJuLNkUYl4VokHKlPhVDg4YY= +github.com/mmcloughlin/avo v0.6.0/go.mod h1:8CoAGaCSYXtCPR+8y18Y9aB/kxb8JSS6FRI7mSkvD+8= +golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= +golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk= +golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys= diff --git a/internal/onnx/operators/_gen/pow/main.go b/internal/onnx/operators/_gen/pow/main.go new file mode 100644 index 0000000..cc91c99 --- /dev/null +++ b/internal/onnx/operators/_gen/pow/main.go @@ -0,0 +1,220 @@ +// Command pow generates the vendored AVX2+FMA float32 pow(x, c) kernel for the +// operators package. Run via `go generate` (see pow_simd_amd64.go); lives in a +// separate module (_gen/pow/go.mod) so avo never enters born's module graph. The +// generated artifacts (pow_simd_amd64.s and its Go stub) are committed. +// +// powConstF32AVX2 computes out[i] = pow(in[i], c) = exp(c*log(in[i])) for n +// (multiple of 8) float32 lanes, 8 at a time, with a constant exponent c. log is +// the Cephes single-precision logf (frexp + minimax polynomial) and exp is the +// Cephes expf with Cody-Waite range reduction; both are ~1 ULP in float32, so the +// composed result is well inside the model's 1e-3 parity budget. Non-positive +// inputs are flushed to 0 (pow(0, c>0) == 0; the bitwise frexp cannot represent +// 0/negatives), matching math.Pow over the non-negative ONNX Pow domain. +package main + +import ( + "fmt" + "math" + + . "github.com/mmcloughlin/avo/build" + . "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +// Cephes single-precision logf constants (Moshier). +const ( + logP0 = 7.0376836292e-2 + logP1 = -1.1514610310e-1 + logP2 = 1.1676998740e-1 + logP3 = -1.2420140846e-1 + logP4 = 1.4249322787e-1 + logP5 = -1.6668057665e-1 + logP6 = 2.0000714765e-1 + logP7 = -2.4999993993e-1 + logP8 = 3.3333331174e-1 + sqrtHF = 0.7071067690849304 + + // ln(2) split, shared by logf and the expf range reduction. + ln2hi = 0.693359375 + ln2lo = -2.12194440e-4 + + // Cephes expf constants. + expHi = 88.3762626647949 + expLo = -88.3762626647949 + log2ef = 1.44269504088896341 + expP0 = 1.9875691500e-4 + expP1 = 1.3981999507e-3 + expP2 = 8.3334519073e-3 + expP3 = 4.1665795894e-2 + expP4 = 1.6666665459e-1 + expP5 = 5.0000001201e-1 + + half = 0.5 + one = 1.0 +) + +var f32pool = map[uint32]Mem{} + +// cf returns a RODATA Mem holding val as a single float32, deduplicated by bits. +func cf(val float32) Mem { + bits := math.Float32bits(val) + if m, ok := f32pool[bits]; ok { + return m + } + m := GLOBL(fmt.Sprintf("powf32_%08x", bits), RODATA|NOPTR) + DATA(0, U32(bits)) + f32pool[bits] = m + return m +} + +var i32pool = map[uint32]Mem{} + +// ci returns a RODATA Mem holding val as a single uint32, deduplicated. +func ci(val uint32) Mem { + if m, ok := i32pool[val]; ok { + return m + } + m := GLOBL(fmt.Sprintf("powi32_%08x", val), RODATA|NOPTR) + DATA(0, U32(val)) + i32pool[val] = m + return m +} + +// bcf broadcasts a float32 constant into a fresh YMM. +func bcf(val float32) reg.VecVirtual { + y := YMM() + VBROADCASTSS(cf(val), y) + return y +} + +// bci broadcasts a uint32 constant into a fresh YMM for integer-lane ops (AVX2 +// VP* instructions have no embedded broadcast, so materialize a full vector). +func bci(val uint32) reg.VecVirtual { + y := YMM() + VPBROADCASTD(ci(val), y) + return y +} + +func main() { + TEXT("powConstF32AVX2", NOSPLIT, "func(out, in []float32, n int, c float32)") + Doc("powConstF32AVX2 computes out[i] = pow(in[i], c) = exp(c*log(in[i])) for the", + "first n (multiple of 8) float32 lanes using AVX2+FMA Cephes logf and expf.", + "Non-positive inputs are flushed to 0. The caller handles any sub-8 remainder.") + Pragma("noescape") + + outPtr := Load(Param("out").Base(), GP64()) + inPtr := Load(Param("in").Base(), GP64()) + n := Load(Param("n"), GP64()) + cScalar := Load(Param("c"), XMM()) + cVec := YMM() + VBROADCASTSS(cScalar, cVec) + + blocks := GP64() + MOVQ(n, blocks) + SHRQ(Imm(3), blocks) // n / 8 + + zero := YMM() + VXORPS(zero, zero, zero) + + Label("loop") + CMPQ(blocks, Imm(0)) + JE(LabelRef("done")) + + x := YMM() + VMOVUPS(Mem{Base: inPtr}, x) + + // Mask of strictly-positive lanes; applied at the end to flush x<=0 to 0. + posMask := YMM() + VCMPPS(Imm(0x1e), zero, x, posMask) // GT_OQ: x > 0 + + // ---- Cephes logf(x) -> lg ---- + // frexp: e = (bits >> 23) - 126; m = (bits & 0x007fffff) | 0x3f000000 (in [0.5,1)). + ei := YMM() + VPSRLD(Imm(23), x, ei) + VPSUBD(bci(126), ei, ei) + ef := YMM() + VCVTDQ2PS(ei, ef) + + m := YMM() + VPAND(bci(0x007fffff), x, m) + VPOR(bci(0x3f000000), m, m) + + // Branchless SQRTHF adjust: if m < SQRTHF { e -= 1; m = 2m-1 } else { m = m-1 }. + ltMask := YMM() + VCMPPS(Imm(1), bcf(sqrtHF), m, ltMask) // LT_OS: m < SQRTHF + adj := YMM() + VANDPS(ltMask, bcf(one), adj) + VSUBPS(adj, ef, ef) // e -= (m 2m (true) or m (false) + VSUBPS(bcf(one), m, m) // m -= 1 + + z := YMM() + VMULPS(m, m, z) // z = m^2 + + // Horner: poly = (((((((P0*m+P1)*m+P2)...)*m+P8) + poly := bcf(logP0) + for _, p := range []float32{logP1, logP2, logP3, logP4, logP5, logP6, logP7, logP8} { + VFMADD213PS(bcf(p), m, poly) // poly = poly*m + p + } + lg := YMM() + VMULPS(poly, m, lg) // poly * m + VMULPS(lg, z, lg) // * m^2 -> poly * m^3 + + // Corrections: lg += e*ln2lo - 0.5*z + m + e*ln2hi. + VFMADD231PS(bcf(ln2lo), ef, lg) // lg += e*ln2lo + VFMADD231PS(bcf(-half), z, lg) // lg += -0.5*z + VADDPS(m, lg, lg) // lg += m + VFMADD231PS(bcf(ln2hi), ef, lg) // lg += e*ln2hi + + // ---- arg = c * log(x) ---- + arg := YMM() + VMULPS(cVec, lg, arg) + + // ---- Cephes expf(arg) -> y ---- + VMINPS(bcf(expHi), arg, arg) + VMAXPS(bcf(expLo), arg, arg) + + fx := YMM() + VMULPS(bcf(log2ef), arg, fx) + VADDPS(bcf(half), fx, fx) + VROUNDPS(Imm(1), fx, fx) // floor + + VFNMADD231PS(bcf(ln2hi), fx, arg) // arg -= fx*ln2hi + VFNMADD231PS(bcf(ln2lo), fx, arg) // arg -= fx*ln2lo + r := arg + + z2 := YMM() + VMULPS(r, r, z2) + + y := bcf(expP0) + for _, p := range []float32{expP1, expP2, expP3, expP4, expP5} { + VMULPS(r, y, y) + VADDPS(bcf(p), y, y) + } + VMULPS(z2, y, y) + VADDPS(r, y, y) + VADDPS(bcf(one), y, y) + + // 2^fx via integer exponent: ((int(fx)+127) << 23) reinterpreted as float. + ni := YMM() + VCVTTPS2DQ(fx, ni) + VPADDD(bci(127), ni, ni) + VPSLLD(Imm(23), ni, ni) + VMULPS(ni, y, y) // exp = y * 2^fx + + // Flush non-positive inputs to 0. + VANDPS(posMask, y, y) + VMOVUPS(y, Mem{Base: outPtr}) + + ADDQ(Imm(32), inPtr) + ADDQ(Imm(32), outPtr) + DECQ(blocks) + JMP(LabelRef("loop")) + + Label("done") + VZEROUPPER() + RET() + + Generate() +} diff --git a/internal/onnx/operators/math_ops.go b/internal/onnx/operators/math_ops.go index d70df11..ab4790c 100644 --- a/internal/onnx/operators/math_ops.go +++ b/internal/onnx/operators/math_ops.go @@ -55,6 +55,10 @@ func handlePow(_ *Context, _ *Node, inputs []*tensor.RawTensor) ([]*tensor.RawTe od := out.AsFloat32() switch { case len(e) == 1: + if powConstF32 != nil { + powConstF32(od, b, e[0]) // vendored SIMD: exp(c*log(x)), with a scalar tail + break + } ex := float64(e[0]) for i := range b { od[i] = float32(math.Pow(float64(b[i]), ex)) diff --git a/internal/onnx/operators/pow_simd.go b/internal/onnx/operators/pow_simd.go new file mode 100644 index 0000000..aa35b2d --- /dev/null +++ b/internal/onnx/operators/pow_simd.go @@ -0,0 +1,19 @@ +package operators + +// powConstF32 is the optional vendored-SIMD pow with a constant exponent: +// out[i] = pow(in[i], c), computed as exp(c*log(in[i])). It targets the ONNX Pow +// op's scalar-exponent case on non-negative inputs (e.g. mel-spectrogram power +// compression). +// +// It is nil by default and wired by an arch-specific init when the CPU supports +// AVX2+FMA (see pow_simd_amd64.go). When non-nil, handlePow uses it for the +// scalar-exponent case; otherwise the scalar math.Pow loop runs. out and in must +// have the same length. +// +// Domain: inputs must be non-negative. The kernel flushes x<=0 to 0, which +// matches math.Pow for x==0 with c>0 (the BirdNET mel-spectrogram use). It does +// NOT reproduce math.Pow on a negative base (math.Pow yields NaN for a +// non-integer exponent and the signed root for an integer one; the kernel yields +// 0), so a model that feeds a negative base into a scalar-exponent Pow would +// diverge from the scalar path on AVX2 CPUs. No born model does this. +var powConstF32 func(out, in []float32, c float32) diff --git a/internal/onnx/operators/pow_simd_amd64.go b/internal/onnx/operators/pow_simd_amd64.go new file mode 100644 index 0000000..0e3e826 --- /dev/null +++ b/internal/onnx/operators/pow_simd_amd64.go @@ -0,0 +1,36 @@ +//go:build amd64 + +package operators + +//go:generate sh -c "cd _gen/pow && go run . -out ../../pow_simd_amd64.s -stubs ../../pow_simd_stub_gen_amd64.go -pkg operators" + +import ( + "math" + + "golang.org/x/sys/cpu" +) + +// init wires the vendored AVX2+FMA pow kernel into the dispatch whenever the CPU +// supports AVX2+FMA. It compiles into every default amd64 build (no build tag or +// env flag); dispatch is decided here at startup from runtime CPU detection. CPUs +// without AVX2+FMA leave powConstF32 nil and use the scalar path. +func init() { + if cpu.X86.HasAVX2 && cpu.X86.HasFMA { + powConstF32 = powConstAVX2 + } +} + +// powConstAVX2 applies the vendored 8-wide pow(x,c) = exp(c*log(x)) kernel to the +// bulk of in and finishes the sub-8 remainder with scalar math.Pow, so any length +// is handled. c is the constant exponent; inputs are expected non-negative. +func powConstAVX2(out, in []float32, c float32) { + n := len(in) + n8 := n &^ 7 + if n8 > 0 { + powConstF32AVX2(out, in, n8, c) + } + ex := float64(c) + for i := n8; i < n; i++ { + out[i] = float32(math.Pow(float64(in[i]), ex)) + } +} diff --git a/internal/onnx/operators/pow_simd_amd64.s b/internal/onnx/operators/pow_simd_amd64.s new file mode 100644 index 0000000..fd8b997 --- /dev/null +++ b/internal/onnx/operators/pow_simd_amd64.s @@ -0,0 +1,195 @@ +// Code generated by command: go run main.go -out ../../pow_simd_amd64.s -stubs ../../pow_simd_stub_gen_amd64.go -pkg operators. DO NOT EDIT. + +#include "textflag.h" + +// func powConstF32AVX2(out []float32, in []float32, n int, c float32) +// Requires: AVX, AVX2, FMA3, SSE +TEXT ·powConstF32AVX2(SB), NOSPLIT, $0-60 + MOVQ out_base+0(FP), AX + MOVQ in_base+24(FP), CX + MOVQ n+48(FP), DX + MOVSS c+56(FP), X0 + VBROADCASTSS X0, Y0 + SHRQ $0x03, DX + VXORPS Y1, Y1, Y1 + +loop: + CMPQ DX, $0x00 + JE done + VMOVUPS (CX), Y2 + VCMPPS $0x1e, Y1, Y2, Y3 + VPSRLD $0x17, Y2, Y4 + VPBROADCASTD powi32_0000007e<>+0(SB), Y5 + VPSUBD Y5, Y4, Y4 + VCVTDQ2PS Y4, Y4 + VPBROADCASTD powi32_007fffff<>+0(SB), Y5 + VPAND Y5, Y2, Y2 + VPBROADCASTD powi32_3f000000<>+0(SB), Y5 + VPOR Y5, Y2, Y2 + VBROADCASTSS powf32_3f3504f3<>+0(SB), Y5 + VCMPPS $0x01, Y5, Y2, Y5 + VBROADCASTSS powf32_3f800000<>+0(SB), Y6 + VANDPS Y5, Y6, Y6 + VSUBPS Y6, Y4, Y4 + VANDPS Y5, Y2, Y6 + VADDPS Y6, Y2, Y2 + VBROADCASTSS powf32_3f800000<>+0(SB), Y5 + VSUBPS Y5, Y2, Y2 + VMULPS Y2, Y2, Y5 + VBROADCASTSS powf32_3d9021bb<>+0(SB), Y6 + VBROADCASTSS powf32_bdebd1b8<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_3def251a<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_bdfe5d4f<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_3e11e9bf<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_be2aae50<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_3e4cceac<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_be7ffffc<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VBROADCASTSS powf32_3eaaaaaa<>+0(SB), Y7 + VFMADD213PS Y7, Y2, Y6 + VMULPS Y6, Y2, Y6 + VMULPS Y6, Y5, Y6 + VBROADCASTSS powf32_b95e8083<>+0(SB), Y7 + VFMADD231PS Y7, Y4, Y6 + VBROADCASTSS powf32_bf000000<>+0(SB), Y7 + VFMADD231PS Y7, Y5, Y6 + VADDPS Y2, Y6, Y6 + VBROADCASTSS powf32_3f318000<>+0(SB), Y2 + VFMADD231PS Y2, Y4, Y6 + VMULPS Y0, Y6, Y2 + VBROADCASTSS powf32_42b0c0a5<>+0(SB), Y4 + VMINPS Y4, Y2, Y2 + VBROADCASTSS powf32_c2b0c0a5<>+0(SB), Y4 + VMAXPS Y4, Y2, Y2 + VBROADCASTSS powf32_3fb8aa3b<>+0(SB), Y4 + VMULPS Y4, Y2, Y4 + VBROADCASTSS powf32_3f000000<>+0(SB), Y5 + VADDPS Y5, Y4, Y4 + VROUNDPS $0x01, Y4, Y4 + VBROADCASTSS powf32_3f318000<>+0(SB), Y5 + VFNMADD231PS Y5, Y4, Y2 + VBROADCASTSS powf32_b95e8083<>+0(SB), Y5 + VFNMADD231PS Y5, Y4, Y2 + VMULPS Y2, Y2, Y5 + VBROADCASTSS powf32_39506967<>+0(SB), Y6 + VMULPS Y2, Y6, Y6 + VBROADCASTSS powf32_3ab743ce<>+0(SB), Y7 + VADDPS Y7, Y6, Y6 + VMULPS Y2, Y6, Y6 + VBROADCASTSS powf32_3c088908<>+0(SB), Y7 + VADDPS Y7, Y6, Y6 + VMULPS Y2, Y6, Y6 + VBROADCASTSS powf32_3d2aa9c1<>+0(SB), Y7 + VADDPS Y7, Y6, Y6 + VMULPS Y2, Y6, Y6 + VBROADCASTSS powf32_3e2aaaaa<>+0(SB), Y7 + VADDPS Y7, Y6, Y6 + VMULPS Y2, Y6, Y6 + VBROADCASTSS powf32_3f000000<>+0(SB), Y7 + VADDPS Y7, Y6, Y6 + VMULPS Y5, Y6, Y6 + VADDPS Y2, Y6, Y6 + VBROADCASTSS powf32_3f800000<>+0(SB), Y2 + VADDPS Y2, Y6, Y6 + VCVTTPS2DQ Y4, Y2 + VPBROADCASTD powi32_0000007f<>+0(SB), Y4 + VPADDD Y4, Y2, Y2 + VPSLLD $0x17, Y2, Y2 + VMULPS Y2, Y6, Y6 + VANDPS Y3, Y6, Y6 + VMOVUPS Y6, (AX) + ADDQ $0x20, CX + ADDQ $0x20, AX + DECQ DX + JMP loop + +done: + VZEROUPPER + RET + +DATA powi32_0000007e<>+0(SB)/4, $0x0000007e +GLOBL powi32_0000007e<>(SB), RODATA|NOPTR, $4 + +DATA powi32_007fffff<>+0(SB)/4, $0x007fffff +GLOBL powi32_007fffff<>(SB), RODATA|NOPTR, $4 + +DATA powi32_3f000000<>+0(SB)/4, $0x3f000000 +GLOBL powi32_3f000000<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3f3504f3<>+0(SB)/4, $0x3f3504f3 +GLOBL powf32_3f3504f3<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3f800000<>+0(SB)/4, $0x3f800000 +GLOBL powf32_3f800000<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3d9021bb<>+0(SB)/4, $0x3d9021bb +GLOBL powf32_3d9021bb<>(SB), RODATA|NOPTR, $4 + +DATA powf32_bdebd1b8<>+0(SB)/4, $0xbdebd1b8 +GLOBL powf32_bdebd1b8<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3def251a<>+0(SB)/4, $0x3def251a +GLOBL powf32_3def251a<>(SB), RODATA|NOPTR, $4 + +DATA powf32_bdfe5d4f<>+0(SB)/4, $0xbdfe5d4f +GLOBL powf32_bdfe5d4f<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3e11e9bf<>+0(SB)/4, $0x3e11e9bf +GLOBL powf32_3e11e9bf<>(SB), RODATA|NOPTR, $4 + +DATA powf32_be2aae50<>+0(SB)/4, $0xbe2aae50 +GLOBL powf32_be2aae50<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3e4cceac<>+0(SB)/4, $0x3e4cceac +GLOBL powf32_3e4cceac<>(SB), RODATA|NOPTR, $4 + +DATA powf32_be7ffffc<>+0(SB)/4, $0xbe7ffffc +GLOBL powf32_be7ffffc<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3eaaaaaa<>+0(SB)/4, $0x3eaaaaaa +GLOBL powf32_3eaaaaaa<>(SB), RODATA|NOPTR, $4 + +DATA powf32_b95e8083<>+0(SB)/4, $0xb95e8083 +GLOBL powf32_b95e8083<>(SB), RODATA|NOPTR, $4 + +DATA powf32_bf000000<>+0(SB)/4, $0xbf000000 +GLOBL powf32_bf000000<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3f318000<>+0(SB)/4, $0x3f318000 +GLOBL powf32_3f318000<>(SB), RODATA|NOPTR, $4 + +DATA powf32_42b0c0a5<>+0(SB)/4, $0x42b0c0a5 +GLOBL powf32_42b0c0a5<>(SB), RODATA|NOPTR, $4 + +DATA powf32_c2b0c0a5<>+0(SB)/4, $0xc2b0c0a5 +GLOBL powf32_c2b0c0a5<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3fb8aa3b<>+0(SB)/4, $0x3fb8aa3b +GLOBL powf32_3fb8aa3b<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3f000000<>+0(SB)/4, $0x3f000000 +GLOBL powf32_3f000000<>(SB), RODATA|NOPTR, $4 + +DATA powf32_39506967<>+0(SB)/4, $0x39506967 +GLOBL powf32_39506967<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3ab743ce<>+0(SB)/4, $0x3ab743ce +GLOBL powf32_3ab743ce<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3c088908<>+0(SB)/4, $0x3c088908 +GLOBL powf32_3c088908<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3d2aa9c1<>+0(SB)/4, $0x3d2aa9c1 +GLOBL powf32_3d2aa9c1<>(SB), RODATA|NOPTR, $4 + +DATA powf32_3e2aaaaa<>+0(SB)/4, $0x3e2aaaaa +GLOBL powf32_3e2aaaaa<>(SB), RODATA|NOPTR, $4 + +DATA powi32_0000007f<>+0(SB)/4, $0x0000007f +GLOBL powi32_0000007f<>(SB), RODATA|NOPTR, $4 diff --git a/internal/onnx/operators/pow_simd_stub_gen_amd64.go b/internal/onnx/operators/pow_simd_stub_gen_amd64.go new file mode 100644 index 0000000..1495037 --- /dev/null +++ b/internal/onnx/operators/pow_simd_stub_gen_amd64.go @@ -0,0 +1,10 @@ +// Code generated by command: go run main.go -out ../../pow_simd_amd64.s -stubs ../../pow_simd_stub_gen_amd64.go -pkg operators. DO NOT EDIT. + +package operators + +// powConstF32AVX2 computes out[i] = pow(in[i], c) = exp(c*log(in[i])) for the +// first n (multiple of 8) float32 lanes using AVX2+FMA Cephes logf and expf. +// Non-positive inputs are flushed to 0. The caller handles any sub-8 remainder. +// +//go:noescape +func powConstF32AVX2(out []float32, in []float32, n int, c float32) diff --git a/internal/onnx/operators/pow_simd_test.go b/internal/onnx/operators/pow_simd_test.go new file mode 100644 index 0000000..0e1b718 --- /dev/null +++ b/internal/onnx/operators/pow_simd_test.go @@ -0,0 +1,110 @@ +package operators + +import ( + "math" + "math/rand" + "testing" + + "golang.org/x/sys/cpu" +) + +// powTestInput builds a non-negative test slice covering exact zeros plus small, +// unit, and typical magnitudes, which is the domain of the ONNX Pow op on a mel +// spectrogram (base >= 0). +func powTestInput(r *rand.Rand, n int) []float32 { + s := make([]float32, n) + for i := range s { + switch i % 7 { + case 0: + s[i] = 0 // pow(0, c>0) == 0 + case 1: + s[i] = r.Float32() * 0.01 // small + case 2: + s[i] = 1 + default: + s[i] = r.Float32() * 50 // typical magnitude + } + } + return s +} + +var powExponents = []float32{0.1905273, 0.22952409, 0.43, 0.5, 1.5, 2.0, 0.1} + +// TestPowConstSIMDParity checks the vendored AVX2 pow(x,c) kernel against the +// scalar math.Pow reference across the BirdNET exponents and representative +// non-negative inputs spanning full vector blocks plus sub-8 tails. +func TestPowConstSIMDParity(t *testing.T) { + if !cpu.X86.HasAVX2 || !cpu.X86.HasFMA { + t.Skip("AVX2+FMA not available") + } + if powConstF32 == nil { + t.Skip("vendored SIMD pow not wired") + } + r := rand.New(rand.NewSource(0x706f77)) // "pow" + var globalMax float64 + for _, c := range powExponents { + for _, n := range []int{1, 7, 8, 9, 17, 256, 49056} { + in := powTestInput(r, n) + got := make([]float32, n) + powConstF32(got, in, c) + for i := range in { + want := float32(math.Pow(float64(in[i]), float64(c))) + d := math.Abs(float64(got[i]-want)) / (1 + math.Abs(float64(want))) + if d > globalMax { + globalMax = d + } + // exp(c*log(x)) reorders rounding vs float64 math.Pow; 1e-4 relative + // is comfortably inside the model's 1e-3 parity budget. + if d > 1e-4 { + t.Errorf("c=%v n=%d i=%d: rel diff %.3e exceeds 1e-4 (got %v want %v)", + c, n, i, d, got[i], want) + } + } + } + } + t.Logf("max relative error vs math.Pow: %.3e", globalMax) +} + +// BenchmarkPowConst compares scalar math.Pow against the vendored SIMD kernel at +// the BirdNET mel-spectrogram Pow size (1*511*96) and a smaller tensor. +func BenchmarkPowConst(b *testing.B) { + r := rand.New(rand.NewSource(9)) + const c = float32(0.1905273) + for _, s := range []struct { + name string + n int + }{ + {"mel_511x96", 49056}, // 1*511*96, the BirdNET Pow tensor + {"n4096", 4096}, + } { + in := powTestInput(r, s.n) + out := make([]float32, s.n) + b.Run(s.name+"/scalar", func(b *testing.B) { + ex := float64(c) + for i := 0; i < b.N; i++ { + for j := range in { + out[j] = float32(math.Pow(float64(in[j]), ex)) + } + } + }) + b.Run(s.name+"/simd", func(b *testing.B) { + if powConstF32 == nil { + b.Skip("no SIMD pow") + } + for i := 0; i < b.N; i++ { + powConstF32(out, in, c) + } + }) + } +} + +// TestPowConstWiredIn asserts the always-on dispatch contract: the vendored SIMD +// pow is wired exactly when the CPU supports AVX2+FMA, so a dropped init() or a +// flipped CPU check is caught instead of silently skipping the SIMD tests. +func TestPowConstWiredIn(t *testing.T) { + want := cpu.X86.HasAVX2 && cpu.X86.HasFMA + if got := powConstF32 != nil; got != want { + t.Errorf("powConstF32 wired = %v, want %v (HasAVX2=%v HasFMA=%v)", + got, want, cpu.X86.HasAVX2, cpu.X86.HasFMA) + } +}