Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 36 additions & 37 deletions connection_instrumented.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ package pop
import (
"database/sql"
"database/sql/driver"
"context"
"fmt"
"slices"
"sync"

"github.com/jmoiron/sqlx"
"github.com/luna-duclos/instrumentedsql"
Expand All @@ -16,12 +15,8 @@ import (
"github.com/gobuffalo/pop/v6/logging"
)

const instrumentedDriverName = "instrumented-sql-driver"

var sqlDriverLock = sync.Mutex{}

func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (driverName, dialect string, err error) {
driverName = defaultDriverName
func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drv driver.Driver, dialect string, err error) {
driverName := defaultDriverName
if deets.Driver != "" {
driverName = deets.Driver
}
Expand All @@ -35,8 +30,8 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive
)
}

// If instrumentation is disabled, we just return the driver name we got (e.g. "pgx").
return driverName, dialect, nil
// If instrumentation is disabled, return nil driver to signal non-instrumented path.
return nil, dialect, nil
}

if len(deets.InstrumentedDriverOptions) == 0 {
Expand All @@ -46,41 +41,23 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive
)
}

var dr driver.Driver
var newDriverName string
switch CanonicalDialect(driverName) {
case nameCockroach:
fallthrough
case namePostgreSQL:
dr = new(pgx.Driver)
newDriverName = instrumentedDriverName + "-" + namePostgreSQL
drv = new(pgx.Driver)
case nameMariaDB:
fallthrough
case nameMySQL:
dr = mysqld.MySQLDriver{}
newDriverName = instrumentedDriverName + "-" + nameMySQL
drv = mysqld.MySQLDriver{}
case nameSQLite3:
var err error
dr, err = newSQLiteDriver()
drv, err = newSQLiteDriver()
if err != nil {
return "", "", err
return nil, "", err
}
newDriverName = instrumentedDriverName + "-" + nameSQLite3
}

sqlDriverLock.Lock()
defer sqlDriverLock.Unlock()

var found bool
if slices.Contains(sql.Drivers(), newDriverName) {
found = true
}

if !found {
sql.Register(newDriverName, instrumentedsql.WrapDriver(dr, deets.InstrumentedDriverOptions...))
}

return newDriverName, dialect, nil
return instrumentedsql.WrapDriver(drv, deets.InstrumentedDriverOptions...), dialect, nil
}

// openPotentiallyInstrumentedConnection first opens a raw SQL connection and then wraps it with `sqlx`.
Expand All @@ -90,15 +67,37 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive
// a custom driver name when using instrumentation, this detection would fail
// otherwise.
func openPotentiallyInstrumentedConnection(c dialect, dsn string) (*sqlx.DB, error) {
driverName, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver())
drv, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver())
if err != nil {
return nil, err
}

con, err := sql.Open(driverName, dsn)
if err != nil {
return nil, fmt.Errorf("could not open database connection: %w", err)
var con *sql.DB
if drv != nil {
// Use sql.OpenDB with the per-connection wrapped driver instead of
// sql.Register + sql.Open, which only registers one driver per driver type
// regardless of per-connection options.
con = sql.OpenDB(&driverConnector{drv, dsn})
} else {
con, err = sql.Open(c.DefaultDriver(), dsn)
if err != nil {
return nil, fmt.Errorf("could not open database connection: %w", err)
}
}

return sqlx.NewDb(con, dialect), nil
}

// driverConnector wraps a driver.Driver with a DSN to implement driver.Connector.
type driverConnector struct {
driver driver.Driver
dsn string
}

func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
return dc.driver.Open(dc.dsn)
}

func (dc *driverConnector) Driver() driver.Driver {
return dc.driver
}