diff --git a/lib/column/bigint.go b/lib/column/bigint.go index 6c9e4c26d9..be45826a2b 100644 --- a/lib/column/bigint.go +++ b/lib/column/bigint.go @@ -68,17 +68,23 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) { case []big.Int: nulls = make([]uint8, len(v)) for i := range v { - col.append(&v[i]) + if err := col.append(&v[i]); err != nil { + return nil, err + } } case []*big.Int: nulls = make([]uint8, len(v)) for i := range v { switch { case v[i] != nil: - col.append(v[i]) + if err := col.append(v[i]); err != nil { + return nil, err + } default: nulls[i] = 1 - col.append(big.NewInt(0)) + if err := col.append(big.NewInt(0)); err != nil { + return nil, err + } } } default: @@ -106,16 +112,16 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) { func (col *BigInt) AppendRow(v any) error { switch v := v.(type) { case big.Int: - col.append(&v) + return col.append(&v) case *big.Int: switch { case v != nil: - col.append(v) + return col.append(v) default: - col.append(big.NewInt(0)) + return col.append(big.NewInt(0)) } case nil: - col.append(big.NewInt(0)) + return col.append(big.NewInt(0)) default: if valuer, ok := v.(driver.Valuer); ok { val, err := valuer.Value() @@ -135,7 +141,6 @@ func (col *BigInt) AppendRow(v any) error { From: fmt.Sprintf("%T", v), } } - return nil } func (col *BigInt) Decode(reader *proto.Reader, rows int) error { @@ -177,9 +182,16 @@ func (col *BigInt) row(i int) *big.Int { return big.NewInt(0) } -func (col *BigInt) append(v *big.Int) { +func (col *BigInt) append(v *big.Int) error { dest := make([]byte, col.size) - bigIntToRaw(dest, new(big.Int).Set(v)) + if err := bigIntToRaw(dest, v, col.signed); err != nil { + return &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: "big.Int", + Hint: err.Error(), + } + } switch v := col.col.(type) { case *proto.ColInt128: v.Append(proto.Int128{ @@ -214,17 +226,39 @@ func (col *BigInt) append(v *big.Int) { }, }) } + return nil } -func bigIntToRaw(dest []byte, v *big.Int) { +func bigIntToRaw(dest []byte, v *big.Int, signed bool) error { + bits := len(dest) * 8 + if signed { + if v.Sign() >= 0 { + if v.BitLen() > bits-1 { + return fmt.Errorf("value overflows %d-byte signed buffer", len(dest)) + } + } else { + if new(big.Int).Not(v).BitLen() > bits-1 { + return fmt.Errorf("value overflows %d-byte signed buffer", len(dest)) + } + } + } else { + if v.Sign() < 0 { + return fmt.Errorf("negative value not allowed for unsigned type") + } + if v.BitLen() > bits { + return fmt.Errorf("value overflows %d-byte unsigned buffer", len(dest)) + } + } + var sign int if v.Sign() < 0 { - v.Not(v).FillBytes(dest) + new(big.Int).Not(v).FillBytes(dest) sign = -1 } else { v.FillBytes(dest) } endianSwap(dest, sign < 0) + return nil } func rawToBigInt(v []byte, signed bool) *big.Int { diff --git a/lib/column/decimal.go b/lib/column/decimal.go index 66d103d873..aab7e5e010 100644 --- a/lib/column/decimal.go +++ b/lib/column/decimal.go @@ -6,7 +6,7 @@ import ( "encoding/binary" "errors" "fmt" - "math/big" + "math" "reflect" "strconv" "strings" @@ -139,18 +139,24 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { case []decimal.Decimal: nulls = make([]uint8, len(v)) for i := range v { - col.append(&v[i]) + if err := col.append(&v[i]); err != nil { + return nil, err + } } case []*decimal.Decimal: nulls = make([]uint8, len(v)) for i := range v { switch { case v[i] != nil: - col.append(v[i]) + if err := col.append(v[i]); err != nil { + return nil, err + } default: nulls[i] = 1 value := decimal.New(0, 0) - col.append(&value) + if err := col.append(&value); err != nil { + return nil, err + } } } case []string: @@ -160,7 +166,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if err != nil { return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", v[i], err) } - col.append(&d) + if err := col.append(&d); err != nil { + return nil, err + } } case []*string: nulls = make([]uint8, len(v)) @@ -168,8 +176,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if v[i] == nil { nulls[i] = 1 value := decimal.New(0, 0) - col.append(&value) - + if err := col.append(&value); err != nil { + return nil, err + } continue } @@ -177,7 +186,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if err != nil { return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", *v[i], err) } - col.append(&d) + if err := col.append(&d); err != nil { + return nil, err + } } default: if valuer, ok := v.(driver.Valuer); ok { @@ -244,34 +255,41 @@ func (col *Decimal) AppendRow(v any) error { From: fmt.Sprintf("%T", v), } } - col.append(&value) - return nil + return col.append(&value) } -func (col *Decimal) append(v *decimal.Decimal) { +func (col *Decimal) append(v *decimal.Decimal) error { switch vCol := col.col.(type) { case *proto.ColDecimal32: - var part uint32 - part = uint32(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart()) - vCol.Append(proto.Decimal32(part)) + scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)) + bi := scaled.BigInt() + if !bi.IsInt64() || bi.Int64() > math.MaxInt32 || bi.Int64() < math.MinInt32 { + return fmt.Errorf("value %s overflows decimal32 range", v.String()) + } + vCol.Append(proto.Decimal32(uint32(bi.Int64()))) case *proto.ColDecimal64: - var part uint64 - part = uint64(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart()) - vCol.Append(proto.Decimal64(part)) + scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)) + bi := scaled.BigInt() + if !bi.IsInt64() { + return fmt.Errorf("value %s overflows decimal64 range", v.String()) + } + vCol.Append(proto.Decimal64(uint64(bi.Int64()))) case *proto.ColDecimal128: - var bi *big.Int - bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() + bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() dest := make([]byte, 16) - bigIntToRaw(dest, bi) + if err := bigIntToRaw(dest, bi, true); err != nil { + return fmt.Errorf("value %s overflows decimal128 range", v.String()) + } vCol.Append(proto.Decimal128{ Low: binary.LittleEndian.Uint64(dest[0 : 64/8]), High: binary.LittleEndian.Uint64(dest[64/8 : 128/8]), }) case *proto.ColDecimal256: - var bi *big.Int - bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() + bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() dest := make([]byte, 32) - bigIntToRaw(dest, bi) + if err := bigIntToRaw(dest, bi, true); err != nil { + return fmt.Errorf("value %s overflows decimal256 range", v.String()) + } vCol.Append(proto.Decimal256{ Low: proto.UInt128{ Low: binary.LittleEndian.Uint64(dest[0 : 64/8]), @@ -283,6 +301,7 @@ func (col *Decimal) append(v *decimal.Decimal) { }, }) } + return nil } func (col *Decimal) Decode(reader *proto.Reader, rows int) error { diff --git a/lib/column/decimal_overflow_test.go b/lib/column/decimal_overflow_test.go new file mode 100644 index 0000000000..da4d333980 --- /dev/null +++ b/lib/column/decimal_overflow_test.go @@ -0,0 +1,92 @@ +package column + +import ( + "math/big" + "testing" + + "github.com/ClickHouse/ch-go/proto" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecimal32OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(9, 2)") + require.NoError(t, err) + + // max int32 is 2147483647; scaled by 10^2 the max representable value is ~21474836.47 + overflow, err := decimal.NewFromString("21474836.48") + require.NoError(t, err) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal64OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(18, 2)") + require.NoError(t, err) + + // max int64 is 9223372036854775807; scaled by 10^2 the max representable value is ~92233720368547758.07 + overflow, err := decimal.NewFromString("92233720368547758.08") + require.NoError(t, err) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal128OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(38, 0)") + require.NoError(t, err) + + // 2^127 exceeds Decimal128 signed range + big2_127 := new(big.Int).Lsh(big.NewInt(1), 127) + overflow := decimal.NewFromBigInt(big2_127, 0) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal256OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(76, 0)") + require.NoError(t, err) + + // 2^255 exceeds Decimal256 signed range + big2_255 := new(big.Int).Lsh(big.NewInt(1), 255) + overflow := decimal.NewFromBigInt(big2_255, 0) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestBigIntOverflowReturnsError(t *testing.T) { + // Int128: signed 128-bit, max positive is 2^127-1 + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + big2_127 := new(big.Int).Lsh(big.NewInt(1), 127) + err := col128.AppendRow(*big2_127) + assert.ErrorContains(t, err, "overflow") +} + +func TestBigIntValidValuesNoError(t *testing.T) { + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + // 2^127 - 1 is the max valid Int128 value + maxInt128 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 127), big.NewInt(1)) + err := col128.AppendRow(*maxInt128) + assert.NoError(t, err) + + // min valid Int128 value is -2^127 + minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127)) + err = col128.AppendRow(*minInt128) + assert.NoError(t, err) +} + +func TestBigIntNegativeOverflowReturnsError(t *testing.T) { + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + // -2^127 - 1 is below the minimum Int128 value (-2^127) + minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127)) + overflow := new(big.Int).Sub(minInt128, big.NewInt(1)) + err := col128.AppendRow(*overflow) + assert.ErrorContains(t, err, "overflow") +} diff --git a/tests/issues/issue_1849_test.go b/tests/issues/issue_1849_test.go new file mode 100644 index 0000000000..f500fc5caa --- /dev/null +++ b/tests/issues/issue_1849_test.go @@ -0,0 +1,382 @@ +package issues + +import ( + "context" + "database/sql" + "fmt" + "math/big" + "strconv" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" + clickhouse_std_tests "github.com/ClickHouse/clickhouse-go/v2/tests/std" +) + +// TestDecimalOverflow verifies that appending values to a Decimal(38,0) column +// that exceed the 38-digit precision returns an error containing "overflow" +// instead of silently producing wrong data or panicking. +// +// Regression test for https://github.com/ClickHouse/clickhouse-go/issues/1849. +func TestDecimalOverflow(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 (d128 Decimal(38, 0)) Engine MergeTree() ORDER BY tuple()` + + maxDecimal128, _ := decimal.NewFromString("99999999999999999999999999999999999999") + justAboveMax, _ := decimal.NewFromString("100000000000000000000000000000000000000") + minDecimal128, _ := decimal.NewFromString("-99999999999999999999999999999999999999") + justBelowMin, _ := decimal.NewFromString("-100000000000000000000000000000000000000") + + cases := []struct { + name string + value decimal.Decimal + }{ + {"positive_overflow_above_max", justAboveMax}, + {"negative_overflow_below_min", justBelowMin}, + {"valid_max_boundary", maxDecimal128}, + {"valid_min_boundary", minDecimal128}, + } + + t.Run("Native", func(t *testing.T) { + ctx := context.Background() + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + conn, err := clickhouse_tests.GetConnection(testSet, t, protocol, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849")) + require.NoError(t, conn.Exec(ctx, ddl)) + t.Cleanup(func() { _ = conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_issue_1849") + require.NoError(t, err) + t.Cleanup(func() { _ = batch.Abort() }) + + err = batch.Append(tc.value) + assertOverflow(t, err, tc.name, "valid_max_boundary", "valid_min_boundary") + }) + } + }) + } + }) + + t.Run("Std", func(t *testing.T) { + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + db, err := clickhouse_std_tests.GetDSNConnection(testSet, protocol, useSSL, nil) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") + _, err = db.Exec(ddl) + require.NoError(t, err) + t.Cleanup(func() { _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := stdInsertOneDecimal(db, tc.value) + assertOverflow(t, err, tc.name, "valid_max_boundary", "valid_min_boundary") + }) + } + }) + } + }) +} + +// TestDecimalSilentDataCorruption verifies that Decimal32 and Decimal64 columns +// no longer silently truncate overflow values via IntPart() casts. The driver +// must return an error containing "overflow" instead of silently producing +// incorrect data. +// +// Regression test for https://github.com/ClickHouse/clickhouse-go/issues/1849. +func TestDecimalSilentDataCorruption(t *testing.T) { + t.Run("Decimal32", func(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 (d32 Decimal(9, 0)) Engine MergeTree() ORDER BY tuple()` + + maxDecimal32, _ := decimal.NewFromString("999999999") + justAboveMax32, _ := decimal.NewFromString("1000000000") + minDecimal32, _ := decimal.NewFromString("-999999999") + justBelowMin32, _ := decimal.NewFromString("-1000000000") + + cases := []struct { + name string + value decimal.Decimal + }{ + {"positive_overflow_above_max", justAboveMax32}, + {"negative_overflow_below_min", justBelowMin32}, + {"valid_max_boundary", maxDecimal32}, + {"valid_min_boundary", minDecimal32}, + } + + runDecimalOverflowTest(t, ddl, cases) + }) + + t.Run("Decimal64", func(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 (d64 Decimal(18, 0)) Engine MergeTree() ORDER BY tuple()` + + maxDecimal64, _ := decimal.NewFromString("9999999999999999999") + justAboveMax64, _ := decimal.NewFromString("10000000000000000000") + minDecimal64, _ := decimal.NewFromString("-9999999999999999999") + justBelowMin64, _ := decimal.NewFromString("-10000000000000000000") + + cases := []struct { + name string + value decimal.Decimal + }{ + {"positive_overflow_above_max", justAboveMax64}, + {"negative_overflow_below_min", justBelowMin64}, + {"valid_max_boundary", maxDecimal64}, + {"valid_min_boundary", minDecimal64}, + } + + runDecimalOverflowTest(t, ddl, cases) + }) +} + +// runDecimalOverflowTest runs the given test cases against all 4 surface +// combinations (Native TCP, Native HTTP, Std TCP, Std HTTP) using the provided DDL. +func runDecimalOverflowTest(t *testing.T, ddl string, cases []struct { + name string + value decimal.Decimal +}) { + t.Helper() + + t.Run("Native", func(t *testing.T) { + ctx := context.Background() + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + conn, err := clickhouse_tests.GetConnection(testSet, t, protocol, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849")) + require.NoError(t, conn.Exec(ctx, ddl)) + t.Cleanup(func() { _ = conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_issue_1849") + require.NoError(t, err) + t.Cleanup(func() { _ = batch.Abort() }) + + err = batch.Append(tc.value) + assertOverflow(t, err, tc.name, "valid_max_boundary", "valid_min_boundary") + }) + } + }) + } + }) + + t.Run("Std", func(t *testing.T) { + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + db, err := clickhouse_std_tests.GetDSNConnection(testSet, protocol, useSSL, nil) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") + _, err = db.Exec(ddl) + require.NoError(t, err) + t.Cleanup(func() { _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := stdInsertOneDecimal(db, tc.value) + assertOverflow(t, err, tc.name, "valid_max_boundary", "valid_min_boundary") + }) + } + }) + } + }) +} + +// TestBigIntOverflow verifies that appending values to Int128 and UInt128 +// columns that exceed the type's range returns an error containing "overflow" +// instead of panicking with "math/big: buffer too small". +// +// Regression test for https://github.com/ClickHouse/clickhouse-go/issues/1849. +func TestBigIntOverflow(t *testing.T) { + t.Run("Int128", func(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 (i128 Int128) Engine MergeTree() ORDER BY tuple()` + + maxInt128 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 127), big.NewInt(1)) + minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127)) + justAboveMaxInt128 := new(big.Int).Add(maxInt128, big.NewInt(1)) + justBelowMinInt128 := new(big.Int).Sub(minInt128, big.NewInt(1)) + + cases := []struct { + name string + value *big.Int + }{ + {"positive_overflow", justAboveMaxInt128}, + {"negative_overflow", justBelowMinInt128}, + {"valid_max_boundary", maxInt128}, + {"valid_min_boundary", minInt128}, + } + + runBigIntOverflowTest(t, ddl, cases, false) + }) + + t.Run("UInt128", func(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 (u128 UInt128) Engine MergeTree() ORDER BY tuple()` + + maxUInt128 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewInt(1)) + justAboveMaxUInt128 := new(big.Int).Lsh(big.NewInt(1), 128) + + cases := []struct { + name string + value *big.Int + }{ + {"negative_not_allowed", big.NewInt(-1)}, + {"positive_overflow", justAboveMaxUInt128}, + {"valid_max_boundary", maxUInt128}, + {"valid_zero", big.NewInt(0)}, + } + + runBigIntOverflowTest(t, ddl, cases, true) + }) +} + +// runBigIntOverflowTest runs the given test cases against all 4 surface +// combinations (Native TCP, Native HTTP, Std TCP, Std HTTP) using the provided DDL. +func runBigIntOverflowTest(t *testing.T, ddl string, cases []struct { + name string + value *big.Int +}, unsigned bool) { + t.Helper() + + t.Run("Native", func(t *testing.T) { + ctx := context.Background() + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + conn, err := clickhouse_tests.GetConnection(testSet, t, protocol, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849")) + require.NoError(t, conn.Exec(ctx, ddl)) + t.Cleanup(func() { _ = conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_issue_1849") + require.NoError(t, err) + t.Cleanup(func() { _ = batch.Abort() }) + + err = batch.Append(tc.value) + assertBigIntResult(t, err, tc.name, unsigned) + }) + } + }) + } + }) + + t.Run("Std", func(t *testing.T) { + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + db, err := clickhouse_std_tests.GetDSNConnection(testSet, protocol, useSSL, nil) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") + _, err = db.Exec(ddl) + require.NoError(t, err) + t.Cleanup(func() { _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := stdInsertOneBigInt(db, tc.value) + assertBigIntResult(t, err, tc.name, unsigned) + }) + } + }) + } + }) +} + +// assertOverflow checks that the error contains "overflow" unless the case +// name matches one of the valid boundary names. +func assertOverflow(t *testing.T, err error, name string, validNames ...string) { + t.Helper() + for _, vn := range validNames { + if name == vn { + assert.NoError(t, err) + return + } + } + assert.ErrorContains(t, err, "overflow") +} + +// assertBigIntResult checks the expected outcome for BigInt test cases. +func assertBigIntResult(t *testing.T, err error, name string, unsigned bool) { + t.Helper() + if unsigned && name == "negative_not_allowed" { + assert.ErrorContains(t, err, "negative") + return + } + switch name { + case "valid_max_boundary", "valid_min_boundary", "valid_zero": + assert.NoError(t, err) + default: + assert.ErrorContains(t, err, "overflow") + } +} + +// stdInsertOneDecimal runs a single-row INSERT through the database/sql surface +// for a Decimal column. +func stdInsertOneDecimal(db *sql.DB, value decimal.Decimal) error { + ctx := context.Background() + scope, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin: %w", err) + } + defer func() { _ = scope.Rollback() }() + + stmt, err := scope.PrepareContext(ctx, "INSERT INTO test_issue_1849") + if err != nil { + return fmt.Errorf("prepare: %w", err) + } + defer stmt.Close() + + if _, err := stmt.ExecContext(ctx, value); err != nil { + return err + } + return scope.Commit() +} + +// stdInsertOneBigInt runs a single-row INSERT through the database/sql surface +// for a BigInt column. +func stdInsertOneBigInt(db *sql.DB, value *big.Int) error { + ctx := context.Background() + scope, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin: %w", err) + } + defer func() { _ = scope.Rollback() }() + + stmt, err := scope.PrepareContext(ctx, "INSERT INTO test_issue_1849") + if err != nil { + return fmt.Errorf("prepare: %w", err) + } + defer stmt.Close() + + if _, err := stmt.ExecContext(ctx, value); err != nil { + return err + } + return scope.Commit() +}