Skip to content

Commit

Permalink
fix: allow LIMIT 0 for SELECT queries
Browse files Browse the repository at this point in the history
This commit enables the correct queries of type `SELECT ... LIMIT 0`.
Before that, the limit-clause wasn't applied to the query.
  • Loading branch information
ygabuev committed Oct 20, 2023
1 parent 8a43835 commit 74405b7
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 23 deletions.
53 changes: 53 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ func TestDB(t *testing.T) {
{testNilModel},
{testSelectScan},
{testSelectCount},
{testSelectLimit},
{testSelectMap},
{testSelectMapSlice},
{testSelectStruct},
Expand Down Expand Up @@ -348,6 +349,37 @@ func testSelectCount(t *testing.T, db *bun.DB) {
require.Equal(t, 3, count)
}

func testSelectLimit(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
return
}

values := db.NewValues(&[]map[string]interface{}{
{"num": 1},
{"num": 2},
{"num": 3},
})

q := db.NewSelect().
With("t", values).
Column("t.num").
TableExpr("t")

var nums []int
err := q.Limit(5).Scan(ctx, &nums)
require.NoError(t, err)
require.Equal(t, 3, len(nums))

err = q.Limit(2).Scan(ctx, &nums)
require.NoError(t, err)
require.Equal(t, 2, len(nums))

err = q.Limit(0).Scan(ctx, &nums)
require.NoError(t, err)
require.Equal(t, 0, len(nums))
}

func testSelectMap(t *testing.T, db *bun.DB) {
var m map[string]interface{}
err := db.NewSelect().
Expand Down Expand Up @@ -1344,6 +1376,9 @@ func testScanAndCount(t *testing.T, db *bun.DB) {
})

t.Run("no limit", func(t *testing.T) {
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

src := []Model{
{Str: "str1"},
{Str: "str2"},
Expand All @@ -1357,6 +1392,24 @@ func testScanAndCount(t *testing.T, db *bun.DB) {
require.Equal(t, 2, count)
require.Equal(t, 2, len(dest))
})

t.Run("limit 0", func(t *testing.T) {
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

src := []Model{
{Str: "str1"},
{Str: "str2"},
}
_, err = db.NewInsert().Model(&src).Exec(ctx)
require.NoError(t, err)

var dest []Model
count, err := db.NewSelect().Model(&dest).Limit(0).ScanAndCount(ctx)
require.NoError(t, err)
require.Equal(t, 2, count)
require.Equal(t, 0, len(dest))
})
}

func testEmbedModelValue(t *testing.T, db *bun.DB) {
Expand Down
46 changes: 23 additions & 23 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type SelectQuery struct {
group []schema.QueryWithArgs
having []schema.QueryWithArgs
order []schema.QueryWithArgs
limit int32
limit *int32
offset int32
selFor schema.QueryWithArgs

Expand Down Expand Up @@ -313,7 +313,11 @@ func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery
}

func (q *SelectQuery) Limit(n int) *SelectQuery {
q.limit = int32(n)
if n >= 0 {
l := int32(n)
q.limit = &l
}

return q
}

Expand Down Expand Up @@ -611,29 +615,29 @@ func (q *SelectQuery) appendQuery(
}

if fmter.Dialect().Features().Has(feature.OffsetFetch) {
if q.limit > 0 && q.offset > 0 {
if q.limit != nil && q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)

b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = strconv.AppendInt(b, int64(*q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.limit > 0 {
} else if q.limit != nil {
b = append(b, " OFFSET 0 ROWS"...)

b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = strconv.AppendInt(b, int64(*q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
}
} else {
if q.limit > 0 {
if q.limit != nil {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = strconv.AppendInt(b, int64(*q.limit), 10)
}
if q.offset > 0 {
b = append(b, " OFFSET "...)
Expand Down Expand Up @@ -958,20 +962,18 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{})
var mu sync.Mutex
var firstErr error

if q.limit >= 0 {
wg.Add(1)
go func() {
defer wg.Done()
wg.Add(1)
go func() {
defer wg.Done()

if err := q.Scan(ctx, dest...); err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
if err := q.Scan(ctx, dest...); err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
}()
}
mu.Unlock()
}
}()

wg.Add(1)
go func() {
Expand All @@ -995,9 +997,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{})
func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) {
var firstErr error

if q.limit >= 0 {
firstErr = q.Scan(ctx, dest...)
}
firstErr = q.Scan(ctx, dest...)

count, err := q.Count(ctx)
if err != nil && firstErr == nil {
Expand Down

0 comments on commit 74405b7

Please sign in to comment.