diff --git a/conn.go b/conn.go index a9cb3163f..e586a954c 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +type AfterPgxConnectFunc func(ctx context.Context, pgxconn *Conn) error + // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type ConnConfig struct { @@ -41,6 +43,11 @@ type ConnConfig struct { // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. DefaultQueryExecMode QueryExecMode + // AfterPgxConnect is utilized to modify the pgx.Conn value, such as adding options to the TypeMap value. This is + // called after the AfterConnect function is called, if supplied. If this returns an error, the connection attempt + // fails. + AfterPgxConnect AfterPgxConnectFunc + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -259,6 +266,13 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { return nil, err } + if config.AfterPgxConnect != nil { + if err := config.AfterPgxConnect(ctx, c); err != nil { + c.Close(ctx) + return nil, err + } + } + c.preparedStatements = make(map[string]*pgconn.StatementDescription) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 2be11e820..de5d05d62 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -212,6 +212,10 @@ type Map struct { // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions // should run last. i.e. Additional functions should typically be prepended not appended. TryWrapScanPlanFuncs []TryWrapScanPlanFunc + + // NilHandlers is a slice of functions that is utilized to handle the case where a driver.NamedValue instance is considered + // nil by pgx, but may need to be subsituted for a different value. + NilHandlers []NilHandler } func NewMap() *Map { @@ -1037,6 +1041,17 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { return plan.next.Scan(src, &anyArrayArrayReflect{array: reflect.ValueOf(target).Elem()}) } +type NilHandler func(*driver.NamedValue) error + +func (m *Map) TryNilHandling(nv *driver.NamedValue) error { + for _, f := range m.NilHandlers { + if err := f(nv); err != nil { + return err + } + } + return nil +} + // PlanScan prepares a plan to scan a value into target. func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { oidMemo := m.memoizedScanPlans[oid] diff --git a/stdlib/sql.go b/stdlib/sql.go index 3d65e23ad..3903d5560 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -81,6 +81,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" @@ -533,8 +534,15 @@ func (c *Conn) Ping(ctx context.Context) error { return nil } -func (c *Conn) CheckNamedValue(*driver.NamedValue) error { - // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. +func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { + // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. If the Value itself is not nil, but + // can be considered nil, check to see if an option has been set up to handle it explicitly + if nv.Value != nil && anynil.Is(nv.Value) { + tm := c.conn.TypeMap() + if err := tm.TryNilHandling(nv); err != nil { + return err + } + } return nil }