Skip to content

Commit

Permalink
array arg wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Oct 16, 2024
1 parent 2983a05 commit a7024e3
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 18 deletions.
35 changes: 28 additions & 7 deletions impl/arrays.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,48 @@ import (
"github.com/lib/pq"
)

func WrapForArray(a any) interface {
type ValuerScanner interface {
driver.Valuer
sql.Scanner
} {
}

func WrapArray(a any) ValuerScanner {
// TODO replace with own implementation
return pq.Array(a)
}

func ShouldWrapForArray(v reflect.Value) bool {
func NeedsArrayWrappingForScanning(v reflect.Value) bool {
t := v.Type()
switch t.Kind() {
case reflect.Slice:
if t.Elem() == typeOfByte {
return false // Byte slices are scanned as strings
}
return !v.Addr().Type().Implements(typeOfSQLScanner)
// Byte slices are scanned as strings
return t.Elem() != typeOfByte && !v.Addr().Type().Implements(typeOfSQLScanner)
case reflect.Array:
return !v.Addr().Type().Implements(typeOfSQLScanner)
}
return false
}

func NeedsArrayWrappingForArg(arg any) bool {
t := reflect.TypeOf(arg)
switch t.Kind() {
case reflect.Slice:
// Byte slices are interpreted as strings
return t.Elem() != typeOfByte && !t.Implements(typeOfDriverValuer)
case reflect.Array:
return !t.Implements(typeOfDriverValuer)
}
return false
}

func WrapArrayArgs(args []any) {
for i, arg := range args {
if NeedsArrayWrappingForArg(arg) {
args[i] = WrapArray(arg)
}
}
}

// type ArrayScanner struct {
// Dest reflect.Value
// }
Expand Down
6 changes: 3 additions & 3 deletions impl/arrays_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/domonda/go-types/nullable"
)

func TestShouldWrapForArray(t *testing.T) {
func TestNeedsArrayWrappingForScanning(t *testing.T) {
tests := []struct {
v reflect.Value
want bool
Expand All @@ -27,8 +27,8 @@ func TestShouldWrapForArray(t *testing.T) {
{v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true},
}
for _, tt := range tests {
if got := ShouldWrapForArray(tt.v); got != tt.want {
t.Errorf("shouldWrapArray() = %v, want %v", got, tt.want)
if got := NeedsArrayWrappingForScanning(tt.v); got != tt.want {
t.Errorf("NeedsArrayWrappingForScanning() = %v, want %v", got, tt.want)
}
}
}
14 changes: 8 additions & 6 deletions impl/foreachrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package impl
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"time"
Expand All @@ -11,12 +12,13 @@ import (
)

var (
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
typeOfTime = reflect.TypeOf(time.Time{})
typeOfByte = reflect.TypeOf(byte(0))
typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem()
typeOfError = reflect.TypeFor[error]()
typeOfContext = reflect.TypeFor[context.Context]()
typeOfSQLScanner = reflect.TypeFor[sql.Scanner]()
typeOfDriverValuer = reflect.TypeFor[driver.Valuer]()
typeOfTime = reflect.TypeFor[time.Time]()
typeOfByte = reflect.TypeFor[byte]()
typeOfByteSlice = reflect.TypeFor[[]byte]()
)

// ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row.
Expand Down
4 changes: 2 additions & 2 deletions impl/reflectstruct.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
// If field is a slice or array that does not implement sql.Scanner
// and it's not a string scannable []byte type underneath
// then wrap it with WrapForArray to make it scannable
if ShouldWrapForArray(fieldValue) {
pointer = WrapForArray(pointer)
if NeedsArrayWrappingForScanning(fieldValue) {
pointer = WrapArray(pointer)
}
pointers[colIndex] = pointer
}
Expand Down
3 changes: 3 additions & 0 deletions pqconn/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func (conn *connection) Now() (time.Time, error) {
}

func (conn *connection) Exec(query string, args ...any) error {
impl.WrapArrayArgs(args)
_, err := conn.db.ExecContext(conn.ctx, query, args...)
return wrapError(err, query, argFmt, args)
}
Expand Down Expand Up @@ -162,6 +163,7 @@ func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns
}

func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner {
impl.WrapArrayArgs(args)
rows, err := conn.db.QueryContext(conn.ctx, query, args...)
if err != nil {
err = wrapError(err, query, argFmt, args)
Expand All @@ -171,6 +173,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner {
}

func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner {
impl.WrapArrayArgs(args)
rows, err := conn.db.QueryContext(conn.ctx, query, args...)
if err != nil {
err = wrapError(err, query, argFmt, args)
Expand Down
3 changes: 3 additions & 0 deletions pqconn/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (conn *transaction) Now() (time.Time, error) {
}

func (conn *transaction) Exec(query string, args ...any) error {
impl.WrapArrayArgs(args)
_, err := conn.tx.Exec(query, args...)
return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)
}
Expand Down Expand Up @@ -117,6 +118,7 @@ func (conn *transaction) InsertStructs(table string, rowStructs any, ignoreColum
}

func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner {
impl.WrapArrayArgs(args)
rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...)
if err != nil {
err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)
Expand All @@ -126,6 +128,7 @@ func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner {
}

func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner {
impl.WrapArrayArgs(args)
rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...)
if err != nil {
err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)
Expand Down

0 comments on commit a7024e3

Please sign in to comment.