diff --git a/connection_instrumented.go b/connection_instrumented.go index 92a2dda9..a6504d6e 100644 --- a/connection_instrumented.go +++ b/connection_instrumented.go @@ -3,9 +3,8 @@ package pop import ( "database/sql" "database/sql/driver" + "context" "fmt" - "slices" - "sync" "github.com/jmoiron/sqlx" "github.com/luna-duclos/instrumentedsql" @@ -16,12 +15,8 @@ import ( "github.com/gobuffalo/pop/v6/logging" ) -const instrumentedDriverName = "instrumented-sql-driver" - -var sqlDriverLock = sync.Mutex{} - -func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (driverName, dialect string, err error) { - driverName = defaultDriverName +func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drv driver.Driver, dialect string, err error) { + driverName := defaultDriverName if deets.Driver != "" { driverName = deets.Driver } @@ -35,8 +30,8 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive ) } - // If instrumentation is disabled, we just return the driver name we got (e.g. "pgx"). - return driverName, dialect, nil + // If instrumentation is disabled, return nil driver to signal non-instrumented path. + return nil, dialect, nil } if len(deets.InstrumentedDriverOptions) == 0 { @@ -46,41 +41,23 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive ) } - var dr driver.Driver - var newDriverName string switch CanonicalDialect(driverName) { case nameCockroach: fallthrough case namePostgreSQL: - dr = new(pgx.Driver) - newDriverName = instrumentedDriverName + "-" + namePostgreSQL + drv = new(pgx.Driver) case nameMariaDB: fallthrough case nameMySQL: - dr = mysqld.MySQLDriver{} - newDriverName = instrumentedDriverName + "-" + nameMySQL + drv = mysqld.MySQLDriver{} case nameSQLite3: - var err error - dr, err = newSQLiteDriver() + drv, err = newSQLiteDriver() if err != nil { - return "", "", err + return nil, "", err } - newDriverName = instrumentedDriverName + "-" + nameSQLite3 - } - - sqlDriverLock.Lock() - defer sqlDriverLock.Unlock() - - var found bool - if slices.Contains(sql.Drivers(), newDriverName) { - found = true - } - - if !found { - sql.Register(newDriverName, instrumentedsql.WrapDriver(dr, deets.InstrumentedDriverOptions...)) } - return newDriverName, dialect, nil + return instrumentedsql.WrapDriver(drv, deets.InstrumentedDriverOptions...), dialect, nil } // openPotentiallyInstrumentedConnection first opens a raw SQL connection and then wraps it with `sqlx`. @@ -90,15 +67,37 @@ func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (drive // a custom driver name when using instrumentation, this detection would fail // otherwise. func openPotentiallyInstrumentedConnection(c dialect, dsn string) (*sqlx.DB, error) { - driverName, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver()) + drv, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver()) if err != nil { return nil, err } - con, err := sql.Open(driverName, dsn) - if err != nil { - return nil, fmt.Errorf("could not open database connection: %w", err) + var con *sql.DB + if drv != nil { + // Use sql.OpenDB with the per-connection wrapped driver instead of + // sql.Register + sql.Open, which only registers one driver per driver type + // regardless of per-connection options. + con = sql.OpenDB(&driverConnector{drv, dsn}) + } else { + con, err = sql.Open(c.DefaultDriver(), dsn) + if err != nil { + return nil, fmt.Errorf("could not open database connection: %w", err) + } } return sqlx.NewDb(con, dialect), nil } + +// driverConnector wraps a driver.Driver with a DSN to implement driver.Connector. +type driverConnector struct { + driver driver.Driver + dsn string +} + +func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { + return dc.driver.Open(dc.dsn) +} + +func (dc *driverConnector) Driver() driver.Driver { + return dc.driver +}