diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 8055d6e4f..9b3c6a67e 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -234,6 +234,7 @@ func TestDB(t *testing.T) { {testNilModel}, {testSelectScan}, {testSelectCount}, + {testSelectLimit}, {testSelectMap}, {testSelectMapSlice}, {testSelectStruct}, @@ -347,6 +348,37 @@ func testSelectCount(t *testing.T, db *bun.DB) { require.Equal(t, 3, count) } +func testSelectLimit(t *testing.T, db *bun.DB) { + if !db.Dialect().Features().Has(feature.CTE) { + t.Skip() + return + } + + values := db.NewValues(&[]map[string]interface{}{ + {"num": 1}, + {"num": 2}, + {"num": 3}, + }) + + q := db.NewSelect(). + With("t", values). + Column("t.num"). + TableExpr("t") + + var nums []int + err := q.Limit(5).Scan(ctx, &nums) + require.NoError(t, err) + require.Equal(t, 3, len(nums)) + + err = q.Limit(2).Scan(ctx, &nums) + require.NoError(t, err) + require.Equal(t, 2, len(nums)) + + err = q.Limit(0).Scan(ctx, &nums) + require.NoError(t, err) + require.Equal(t, 0, len(nums)) +} + func testSelectMap(t *testing.T, db *bun.DB) { var m map[string]interface{} err := db.NewSelect(). @@ -1316,6 +1348,9 @@ func testScanAndCount(t *testing.T, db *bun.DB) { }) t.Run("no limit", func(t *testing.T) { + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + src := []Model{ {Str: "str1"}, {Str: "str2"}, @@ -1329,6 +1364,24 @@ func testScanAndCount(t *testing.T, db *bun.DB) { require.Equal(t, 2, count) require.Equal(t, 2, len(dest)) }) + + t.Run("limit 0", func(t *testing.T) { + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + src := []Model{ + {Str: "str1"}, + {Str: "str2"}, + } + _, err = db.NewInsert().Model(&src).Exec(ctx) + require.NoError(t, err) + + var dest []Model + count, err := db.NewSelect().Model(&dest).Limit(0).ScanAndCount(ctx) + require.NoError(t, err) + require.Equal(t, 2, count) + require.Equal(t, 0, len(dest)) + }) } func testEmbedModelValue(t *testing.T, db *bun.DB) { diff --git a/query_select.go b/query_select.go index 932cd48be..ba98092d4 100644 --- a/query_select.go +++ b/query_select.go @@ -31,7 +31,7 @@ type SelectQuery struct { group []schema.QueryWithArgs having []schema.QueryWithArgs order []schema.QueryWithArgs - limit int32 + limit *int32 offset int32 selFor schema.QueryWithArgs @@ -313,7 +313,11 @@ func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery } func (q *SelectQuery) Limit(n int) *SelectQuery { - q.limit = int32(n) + if n >= 0 { + l := int32(n) + q.limit = &l + } + return q } @@ -611,19 +615,19 @@ func (q *SelectQuery) appendQuery( } if fmter.Dialect().Features().Has(feature.OffsetFetch) { - if q.limit > 0 && q.offset > 0 { + if q.limit != nil && q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) b = append(b, " ROWS"...) b = append(b, " FETCH NEXT "...) - b = strconv.AppendInt(b, int64(q.limit), 10) + b = strconv.AppendInt(b, int64(*q.limit), 10) b = append(b, " ROWS ONLY"...) - } else if q.limit > 0 { + } else if q.limit != nil { b = append(b, " OFFSET 0 ROWS"...) b = append(b, " FETCH NEXT "...) - b = strconv.AppendInt(b, int64(q.limit), 10) + b = strconv.AppendInt(b, int64(*q.limit), 10) b = append(b, " ROWS ONLY"...) } else if q.offset > 0 { b = append(b, " OFFSET "...) @@ -631,9 +635,9 @@ func (q *SelectQuery) appendQuery( b = append(b, " ROWS"...) } } else { - if q.limit > 0 { + if q.limit != nil { b = append(b, " LIMIT "...) - b = strconv.AppendInt(b, int64(q.limit), 10) + b = strconv.AppendInt(b, int64(*q.limit), 10) } if q.offset > 0 { b = append(b, " OFFSET "...) @@ -958,20 +962,18 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) var mu sync.Mutex var firstErr error - if q.limit >= 0 { - wg.Add(1) - go func() { - defer wg.Done() + wg.Add(1) + go func() { + defer wg.Done() - if err := q.Scan(ctx, dest...); err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err } - }() - } + mu.Unlock() + } + }() wg.Add(1) go func() { @@ -995,9 +997,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) { var firstErr error - if q.limit >= 0 { - firstErr = q.Scan(ctx, dest...) - } + firstErr = q.Scan(ctx, dest...) count, err := q.Count(ctx) if err != nil && firstErr == nil {