diff --git a/driver.go b/driver.go index 6029c980..3ac51051 100644 --- a/driver.go +++ b/driver.go @@ -87,7 +87,8 @@ func RegisterWithSource(driverName string, source string, options ...DriverOptio return "", errors.New("unable to register driver, all slots have been taken") } -// Wrap takes a SQL driver and wraps it with OpenTelemetry instrumentation. +// Wrap takes an SQL driver and wraps it with OpenTelemetry instrumentation. +// It panics if there is an error when creating instruments. func Wrap(d driver.Driver, opts ...DriverOption) driver.Driver { o := driverOptions{ meterProvider: otel.GetMeterProvider(), @@ -102,14 +103,18 @@ func Wrap(d driver.Driver, opts ...DriverOption) driver.Driver { option.applyDriverOptions(&o) } - return wrapDriver(d, o) + cc, err := newConnConfig(o) + if err != nil { + panic(err) + } + + return wrapDriver(d, cc) } -func wrapDriver(d driver.Driver, o driverOptions) driver.Driver { +func wrapDriver(d driver.Driver, cc connConfig) driver.Driver { drv := otDriver{ parent: d, - connConfig: newConnConfig(o), - close: func() error { return nil }, + connConfig: cc, } if _, ok := d.(driver.DriverContext); ok { @@ -122,7 +127,7 @@ func wrapDriver(d driver.Driver, o driverOptions) driver.Driver { return struct{ driver.Driver }{drv} } -func newConnConfig(opts driverOptions) connConfig { +func newConnConfig(opts driverOptions) (connConfig, error) { meter := opts.meterProvider.Meter(instrumentationName) tracer := newMethodTracer( opts.tracerProvider.Tracer(instrumentationName, @@ -139,13 +144,17 @@ func newConnConfig(opts driverOptions) connConfig { metric.WithUnit(unitMilliseconds), metric.WithDescription(`The distribution of latencies of various calls in milliseconds`), ) - mustNoError(err) + if err != nil { + return connConfig{}, err + } callsCounter, err := meter.Int64Counter(dbSQLClientCalls, metric.WithUnit(unitDimensionless), metric.WithDescription(`The number of various calls of methods`), ) - mustNoError(err) + if err != nil { + return connConfig{}, err + } latencyRecorder := newMethodRecorder(latencyMsHistogram.Record, callsCounter.Add, opts.defaultAttributes...) @@ -161,7 +170,7 @@ func newConnConfig(opts driverOptions) connConfig { queryFuncMiddlewares: makeQueryerContextMiddlewares(latencyRecorder, tracerOrNil(tracer, opts.trace.AllowRoot), newQueryConfig(opts, metricMethodStmtQuery, traceMethodStmtQuery)), queryContextFuncMiddlewares: makeQueryerContextMiddlewares(latencyRecorder, tracer, newQueryConfig(opts, metricMethodStmtQuery, traceMethodStmtQuery)), }), - } + }, nil } var _ driver.Driver = (*otDriver)(nil) @@ -184,7 +193,11 @@ func (d otDriver) Open(name string) (driver.Conn, error) { } func (d otDriver) Close() error { - return d.close() + if d.close != nil { + return d.close() + } + + return nil } func (d otDriver) OpenConnector(name string) (driver.Connector, error) { diff --git a/driver_test.go b/driver_test.go index 3352ca34..a35cc7b8 100644 --- a/driver_test.go +++ b/driver_test.go @@ -259,6 +259,22 @@ func TestWrap_DriverContext_CloseError(t *testing.T) { assert.Equal(t, expectedError, err) } +func TestWrap_Panic(t *testing.T) { + t.Parallel() + + parent := driverOpenFunc(func(string) (driver.Conn, error) { + return nil, errors.New("open error") + }) + + meterProviderOption := otelsql.WithMeterProvider( + oteltest.NewMeterProviderWithError(assert.AnError), + ) + + assert.PanicsWithValue(t, assert.AnError, func() { + _ = otelsql.Wrap(parent, meterProviderOption) + }) +} + func Test_Open_Error(t *testing.T) { t.Parallel() diff --git a/errors.go b/errors.go deleted file mode 100644 index f96f9a54..00000000 --- a/errors.go +++ /dev/null @@ -1,15 +0,0 @@ -package otelsql - -import "go.opentelemetry.io/otel" - -func handleErr(err error) { - if err != nil { - otel.Handle(err) - } -} - -func mustNoError(err error) { - if err != nil { - panic(err) - } -} diff --git a/errors_internal_test.go b/errors_internal_test.go deleted file mode 100644 index 652686c7..00000000 --- a/errors_internal_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package otelsql - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHandleError(t *testing.T) { - t.Parallel() - - assert.Panics(t, func() { - mustNoError(errors.New("error")) - }) - - assert.NotPanics(t, func() { - mustNoError(nil) - }) - - assert.NotPanics(t, func() { - handleErr(nil) - }) - - assert.NotPanics(t, func() { - handleErr(assert.AnError) - }) -} diff --git a/internal/test/oteltest/errors.go b/internal/test/oteltest/errors.go new file mode 100644 index 00000000..033b76be --- /dev/null +++ b/internal/test/oteltest/errors.go @@ -0,0 +1,92 @@ +package oteltest + +import ( + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/embedded" +) + +type errorMeterProvider struct { + embedded.MeterProvider + + Error error +} + +// NewMeterProviderWithError returns a new [metric.MeterProvider] that always +// returns the given error. +func NewMeterProviderWithError(e error) metric.MeterProvider { + return &errorMeterProvider{ + Error: e, + } +} + +func (e *errorMeterProvider) Meter(string, ...metric.MeterOption) metric.Meter { + return &errorMeter{ + Error: e.Error, + } +} + +type errorMeter struct { + embedded.Meter + + Error error +} + +func (e *errorMeter) Int64Counter(string, ...metric.Int64CounterOption) (metric.Int64Counter, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64UpDownCounter(string, ...metric.Int64UpDownCounterOption) (metric.Int64UpDownCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64Histogram(string, ...metric.Int64HistogramOption) (metric.Int64Histogram, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64Gauge(string, ...metric.Int64GaugeOption) (metric.Int64Gauge, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64ObservableCounter(string, ...metric.Int64ObservableCounterOption) (metric.Int64ObservableCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64ObservableUpDownCounter(string, ...metric.Int64ObservableUpDownCounterOption) (metric.Int64ObservableUpDownCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Int64ObservableGauge(string, ...metric.Int64ObservableGaugeOption) (metric.Int64ObservableGauge, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64Counter(string, ...metric.Float64CounterOption) (metric.Float64Counter, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64UpDownCounter(string, ...metric.Float64UpDownCounterOption) (metric.Float64UpDownCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64Histogram(string, ...metric.Float64HistogramOption) (metric.Float64Histogram, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64Gauge(string, ...metric.Float64GaugeOption) (metric.Float64Gauge, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64ObservableCounter(string, ...metric.Float64ObservableCounterOption) (metric.Float64ObservableCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64ObservableUpDownCounter(string, ...metric.Float64ObservableUpDownCounterOption) (metric.Float64ObservableUpDownCounter, error) { + return nil, e.Error +} + +func (e *errorMeter) Float64ObservableGauge(string, ...metric.Float64ObservableGaugeOption) (metric.Float64ObservableGauge, error) { + return nil, e.Error +} + +func (e *errorMeter) RegisterCallback(metric.Callback, ...metric.Observable) (metric.Registration, error) { + return nil, e.Error +} diff --git a/stats.go b/stats.go index d5c09848..5765411f 100644 --- a/stats.go +++ b/stats.go @@ -75,56 +75,72 @@ func recordStats( metric.WithUnit(unitDimensionless), metric.WithDescription("Count of open connections in the pool"), ) - handleErr(err) + if err != nil { + return err + } idleConnections, err = meter.Int64ObservableGauge( dbSQLConnectionsIdle, metric.WithUnit(unitDimensionless), metric.WithDescription("Count of idle connections in the pool"), ) - handleErr(err) + if err != nil { + return err + } activeConnections, err = meter.Int64ObservableGauge( dbSQLConnectionsActive, metric.WithUnit(unitDimensionless), metric.WithDescription("Count of active connections in the pool"), ) - handleErr(err) + if err != nil { + return err + } waitCount, err = meter.Int64ObservableCounter( dbSQLConnectionsWaitCount, metric.WithUnit(unitDimensionless), metric.WithDescription("The total number of connections waited for"), ) - handleErr(err) + if err != nil { + return err + } waitDuration, err = meter.Float64ObservableCounter( dbSQLConnectionsWaitDuration, metric.WithUnit(unitMilliseconds), metric.WithDescription("The total time blocked waiting for a new connection"), ) - handleErr(err) + if err != nil { + return err + } idleClosed, err = meter.Int64ObservableCounter( dbSQLConnectionsIdleClosed, metric.WithUnit(unitDimensionless), metric.WithDescription("The total number of connections closed due to SetMaxIdleConns"), ) - handleErr(err) + if err != nil { + return err + } idleTimeClosed, err = meter.Int64ObservableCounter( dbSQLConnectionsIdleTimeClosed, metric.WithUnit(unitDimensionless), metric.WithDescription("The total number of connections closed due to SetConnMaxIdleTime"), ) - handleErr(err) + if err != nil { + return err + } lifetimeClosed, err = meter.Int64ObservableCounter( dbSQLConnectionsLifetimeClosed, metric.WithUnit(unitDimensionless), metric.WithDescription("The total number of connections closed due to SetConnMaxLifetime"), ) - handleErr(err) + if err != nil { + return err + } _, err = meter.RegisterCallback(func(_ context.Context, obs metric.Observer) error { lock.Lock() diff --git a/stats_test.go b/stats_test.go index 0b18ddb0..d1549434 100644 --- a/stats_test.go +++ b/stats_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" semconv "go.opentelemetry.io/otel/semconv/v1.20.0" @@ -12,6 +13,20 @@ import ( "go.nhat.io/otelsql/internal/test/sqlmock" ) +func TestRecordStatsError(t *testing.T) { + t.Parallel() + + oteltest.New().Run(t, func(sc oteltest.SuiteContext) { + db, err := newDB(sc.DatabaseDSN()) + require.NoError(t, err) + + err = otelsql.RecordStats(db, otelsql.WithMeterProvider( + oteltest.NewMeterProviderWithError(assert.AnError), + )) + require.ErrorIs(t, err, assert.AnError) + }) +} + func TestRecordStats(t *testing.T) { t.Parallel()