Skip to content

Commit

Permalink
replace Connection.Insert methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Dec 1, 2024
1 parent e019dac commit ce32db2
Show file tree
Hide file tree
Showing 15 changed files with 329 additions and 228 deletions.
43 changes: 10 additions & 33 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,18 @@ type (
OnUnlistenFunc func(channel string)
)

// PlaceholderFormatter is an interface for formatting query parameter placeholders
// implemented by database connections.
type PlaceholderFormatter interface {
// Placeholder formats a query parameter placeholder
// for the paramIndex starting at zero.
Placeholder(paramIndex int) string
}

// Connection represents a database connection or transaction
type Connection interface {
PlaceholderFormatter

// Context that all connection operations use.
// See also WithContext.
Context() context.Context
Expand Down Expand Up @@ -53,39 +63,6 @@ type Connection interface {
// Exec executes a query with optional args.
Exec(query string, args ...any) error

// Insert a new row into table using the values.
Insert(table string, values Values) error

// InsertUnique inserts a new row into table using the passed values
// or does nothing if the onConflict statement applies.
// Returns if a row was inserted.
InsertUnique(table string, values Values, onConflict string) (inserted bool, err error)

// InsertReturning inserts a new row into table using values
// and returns values from the inserted row listed in returning.
InsertReturning(table string, values Values, returning string) RowScanner

// InsertStruct inserts a new row into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error

// InsertStructs inserts a slice or array of structs
// as new rows into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
//
// TODO optimized version with single query if possible
// split into multiple queries depending or maxArgs for query
InsertStructs(table string, rowStructs any, ignoreColumns ...ColumnFilter) error

// InsertUniqueStruct inserts a new row into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
// Does nothing if the onConflict statement applies
// and returns if a row was inserted.
InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error)

// Update table rows(s) with values using the where statement with passed in args starting at $1.
Update(table string, values Values, where string, args ...any) error

Expand Down
39 changes: 39 additions & 0 deletions db/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package db

import (
"fmt"

"github.com/domonda/go-sqldb"
"github.com/domonda/go-sqldb/impl"
)

// // WrapNonNilErrorWithQuery wraps non nil errors with a formatted query
// // if the error was not already wrapped with a query.
// // If the passed error is nil, then nil will be returned.
// func WrapNonNilErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error {
// if err == nil {
// return nil
// }
// var wrapped errWithQuery
// if errors.As(err, &wrapped) {
// return err // already wrapped
// }
// return errWithQuery{err, query, args, argFmt}
// }

func wrapErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error {
return errWithQuery{err, query, args, argFmt}
}

type errWithQuery struct {
err error
query string
args []any
argFmt sqldb.PlaceholderFormatter
}

func (e errWithQuery) Unwrap() error { return e.err }

func (e errWithQuery) Error() string {
return fmt.Sprintf("%s from query: %s", e.err, impl.FormatQuery2(e.query, e.argFmt, e.args...))
}
179 changes: 179 additions & 0 deletions db/insert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package db

import (
"context"
"fmt"
"reflect"
"strings"

"github.com/domonda/go-sqldb"
"github.com/domonda/go-sqldb/impl"
)

func writeInsertQuery(w *strings.Builder, table string, names []string, format sqldb.PlaceholderFormatter) {
fmt.Fprintf(w, `INSERT INTO %s(`, table)
for i, name := range names {
if i > 0 {
w.WriteByte(',')
}
w.WriteByte('"')
w.WriteString(name)
w.WriteByte('"')
}
w.WriteString(`) VALUES(`)
for i := range names {
if i > 0 {
w.WriteByte(',')
}
w.WriteString(format.Placeholder(i))
}
w.WriteByte(')')
}

func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) {
v := reflect.ValueOf(rowStruct)
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
switch {
case v.Kind() == reflect.Ptr && v.IsNil():
return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table)
case v.Kind() != reflect.Struct:
return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct)
}

columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly))
return columns, vals, nil
}

// Insert a new row into table using the values.
func Insert(ctx context.Context, table string, values sqldb.Values) error {
if len(values) == 0 {
return fmt.Errorf("Insert into table %s: no values", table)
}
conn := Conn(ctx)

var query strings.Builder
names, vals := values.Sorted()
writeInsertQuery(&query, table, names, conn)

err := conn.Exec(query.String(), vals...)
if err != nil {
return wrapErrorWithQuery(err, query.String(), vals, conn)
}
return nil
}

// InsertUnique inserts a new row into table using the passed values
// or does nothing if the onConflict statement applies.
// Returns if a row was inserted.
func InsertUnique(ctx context.Context, table string, values sqldb.Values, onConflict string) (inserted bool, err error) {
if len(values) == 0 {
return false, fmt.Errorf("InsertUnique into table %s: no values", table)
}
conn := Conn(ctx)

if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") {
onConflict = onConflict[1 : len(onConflict)-1]
}

var query strings.Builder
names, vals := values.Sorted()
writeInsertQuery(&query, table, names, conn)
fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict)

err = conn.QueryRow(query.String(), vals...).Scan(&inserted)
err = sqldb.ReplaceErrNoRows(err, nil)
if err != nil {
return false, wrapErrorWithQuery(err, query.String(), vals, conn)
}
return inserted, err
}

// InsertReturning inserts a new row into table using values
// and returns values from the inserted row listed in returning.
func InsertReturning(ctx context.Context, table string, values sqldb.Values, returning string) sqldb.RowScanner {
if len(values) == 0 {
return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table))
}
conn := Conn(ctx)

var query strings.Builder
names, vals := values.Sorted()
writeInsertQuery(&query, table, names, conn)
query.WriteString(" RETURNING ")
query.WriteString(returning)
return conn.QueryRow(query.String(), vals...) // TODO wrap error with query
}

// InsertStruct inserts a new row into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
conn := Conn(ctx)
columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns)
if err != nil {
return err
}

var query strings.Builder
writeInsertQuery(&query, table, columns, conn)

err = conn.Exec(query.String(), vals...)
if err != nil {
return wrapErrorWithQuery(err, query.String(), vals, conn)
}
return nil
}

// InsertUniqueStruct inserts a new row into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
// Does nothing if the onConflict statement applies
// and returns if a row was inserted.
func InsertUniqueStruct(ctx context.Context, table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) {
conn := Conn(ctx)
columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns)
if err != nil {
return false, err
}

if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") {
onConflict = onConflict[1 : len(onConflict)-1]
}

var query strings.Builder
writeInsertQuery(&query, table, columns, conn)
fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict)

err = conn.QueryRow(query.String(), vals...).Scan(&inserted)
err = sqldb.ReplaceErrNoRows(err, nil)
if err != nil {
return false, wrapErrorWithQuery(err, query.String(), vals, conn)
}
return inserted, err
}

// InsertStructs inserts a slice or array of structs
// as new rows into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
//
// TODO optimized version with single query if possible
// split into multiple queries depending or maxArgs for query
func InsertStructs(ctx context.Context, table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error {
v := reflect.ValueOf(rowStructs)
if k := v.Type().Kind(); k != reflect.Slice && k != reflect.Array {
return fmt.Errorf("InsertStructs expects a slice or array as rowStructs, got %T", rowStructs)
}
numRows := v.Len()
return Transaction(ctx, func(ctx context.Context) error {
for i := 0; i < numRows; i++ {
err := InsertStruct(ctx, table, v.Index(i).Interface(), ignoreColumns...)
if err != nil {
return err
}
}
return nil
})
}
7 changes: 0 additions & 7 deletions db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,3 @@ func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (ro
}
return rows, nil
}

// InsertStruct inserts a new row into table using the connection's
// StructFieldMapper to map struct fields to column names.
// Optional ColumnFilter can be passed to ignore mapped columns.
func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
return Conn(ctx).InsertStruct(table, rowStruct, ignoreColumns...)
}
29 changes: 5 additions & 24 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)

Expand Down Expand Up @@ -208,38 +209,18 @@ func (e connectionWithError) Config() *Config {
return &Config{Err: e.err}
}

func (e connectionWithError) ValidateColumnName(name string) error {
return e.err
}

func (e connectionWithError) Exec(query string, args ...any) error {
return e.err
}

func (e connectionWithError) Insert(table string, values Values) error {
return e.err
func (e connectionWithError) Placeholder(paramIndex int) string {
return fmt.Sprintf("$%d", paramIndex+1)
}

func (e connectionWithError) InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) {
return false, e.err
}

func (e connectionWithError) InsertReturning(table string, values Values, returning string) RowScanner {
return RowScannerWithError(e.err)
}

func (e connectionWithError) InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error {
func (e connectionWithError) ValidateColumnName(name string) error {
return e.err
}

func (e connectionWithError) InsertStructs(table string, rowStructs any, ignoreColumns ...ColumnFilter) error {
func (e connectionWithError) Exec(query string, args ...any) error {
return e.err
}

func (e connectionWithError) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) {
return false, e.err
}

func (e connectionWithError) Update(table string, values Values, where string, args ...any) error {
return e.err
}
Expand Down
28 changes: 4 additions & 24 deletions impl/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (conn *connection) Config() *sqldb.Config {
return conn.config
}

func (conn *connection) Placeholder(paramIndex int) string {
return fmt.Sprintf(conn.argFmt, paramIndex+1)
}

func (conn *connection) ValidateColumnName(name string) error {
return conn.validateColumnName(name)
}
Expand All @@ -87,30 +91,6 @@ func (conn *connection) Exec(query string, args ...any) error {
return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args)
}

func (conn *connection) Insert(table string, columValues sqldb.Values) error {
return Insert(conn, table, conn.argFmt, columValues)
}

func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) {
return InsertUnique(conn, table, conn.argFmt, values, onConflict)
}

func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner {
return InsertReturning(conn, table, conn.argFmt, values, returning)
}

func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns)
}

func (conn *connection) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error {
return InsertStructs(conn, table, rowStructs, ignoreColumns...)
}

func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) {
return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns)
}

func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error {
return Update(conn, table, values, where, conn.argFmt, args)
}
Expand Down
7 changes: 5 additions & 2 deletions impl/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
// if the error was not already wrapped with a query.
// If the passed error is nil, then nil will be returned.
func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error {
if err == nil {
return nil
}
var wrapped errWithQuery
if err == nil || errors.As(err, &wrapped) {
return err
if errors.As(err, &wrapped) {
return err // already wrapped
}
return errWithQuery{err, query, argFmt, args}
}
Expand Down
Loading

0 comments on commit ce32db2

Please sign in to comment.