Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Seq and SeqLEK (Go 1.23 iterators) #244

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions batchget.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (bg *BatchGet) Iter() Iter {

// IterWithTable is like [BatchGet.Iter], but will update the value pointed by tablePtr after each iteration.
// This can be useful when getting from multiple tables to determine which table the latest item came from.
// See: [BatchGet.ItemTableIter] for a nicer way to do this.
//
// For example, you can utilize this iterator to read the results into different structs.
//
Expand Down
83 changes: 83 additions & 0 deletions batchget_go123.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//go:build go1.23

package dynamo

import (
"context"
"iter"
)

type ItemTableIter[V any] interface {
// Items is a sequence of item and table names.
// This is a single use iterator.
// Be sure to check for errors with Err afterwards.
Items(context.Context) iter.Seq2[V, string]
// Err must be checked after iterating.
Err() error
}

// ItemTableIter returns an iterator of (raw item, table name).
// To specify a type, use [BatchGetIter] instead.
//
// For example, you can utilize this iterator to read the results into different structs.
//
// widgetBatch := widgetsTable.Batch("UserID").Get(dynamo.Keys{userID})
// sprocketBatch := sprocketsTable.Batch("UserID").Get(dynamo.Keys{userID})
//
// iter := widgetBatch.Merge(sprocketBatch).ItemTableIter(&table)
//
// // now we will use the table iterator to unmarshal the values into their respective types
// var s sprocket
// var w widget
// for raw, table := range iter.Items {
// if table == "Widgets" {
// err := dynamo.UnmarshalItem(raw, &w)
// if err != nil {
// fmt.Println(err)
// }
// } else if table == "Sprockets" {
// err := dynamo.UnmarshalItem(raw, &s)
// if err != nil {
// fmt.Println(err)
// }
// } else {
// fmt.Printf("Unexpected Table: %s\n", table)
// }
// }
//
// if iter.Err() != nil {
// fmt.Println(iter.Err())
// }
func (bg *BatchGet) ItemTableIter() ItemTableIter[Item] {
return newBgIter2[Item](bg)
}

type bgIter2[V any] struct {
Iter
table string
}

func newBgIter2[V any](bg *BatchGet) *bgIter2[V] {
iter := new(bgIter2[V])
iter.Iter = bg.IterWithTable(&iter.table)
return iter
}

// Items is a sequence of item and table names.
// This is a single use iterator.
// Be sure to check for errors with Err afterwards.
func (iter *bgIter2[V]) Items(ctx context.Context) iter.Seq2[V, string] {
return func(yield func(V, string) bool) {
item := new(V)
for iter.Next(ctx, item) {
if !yield(*item, iter.table) {
break
}
item = new(V)
}
}
}

func BatchGetIter[V any](bg *BatchGet) ItemTableIter[V] {
return newBgIter2[V](bg)
}
6 changes: 6 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,12 @@ func (itr *queryIter) Err() error {
return itr.err
}

func (itr *queryIter) SetError(err error) {
if itr.err == nil {
itr.err = err
}
}

func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
if itr.output != nil {
// if we've hit the end of our results, we can use the real LEK
Expand Down
6 changes: 6 additions & 0 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,12 @@ func (itr *scanIter) Err() error {
return itr.err
}

func (itr *scanIter) SetError(err error) {
if itr.err == nil {
itr.err = err
}
}

// LastEvaluatedKey returns a key that can be used to continue this scan.
// Use with SearchLimit for best results.
func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
Expand Down
40 changes: 40 additions & 0 deletions seq_go123.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//go:build go1.23

package dynamo

import (
"context"
"iter"
)

// Seq returns an item iterator compatible with Go 1.23 `for ... range` loops.
func Seq[V any](ctx context.Context, iter Iter) iter.Seq[V] {
return func(yield func(V) bool) {
item := new(V)
for iter.Next(ctx, item) {
if !yield(*item) {
break
}
item = new(V)
}
}
}

// SeqLEK returns a LastEvaluatedKey and item iterator compatible with Go 1.23 `for ... range` loops.
func SeqLEK[V any](ctx context.Context, iter PagingIter) iter.Seq2[PagingKey, V] {
return func(yield func(PagingKey, V) bool) {
item := new(V)
for iter.Next(ctx, item) {
lek, err := iter.LastEvaluatedKey(ctx)
if err != nil {
if setter, ok := iter.(interface{ SetError(error) }); ok {
setter.SetError(err)
}
}
if !yield(lek, *item) {
break
}
item = new(V)
}
}
}
75 changes: 75 additions & 0 deletions seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//go:build go1.23

package dynamo

import (
"context"
"testing"
"time"
)

func TestSeq(t *testing.T) {
if testDB == nil {
t.Skip(offlineSkipMsg)
}
ctx := context.Background()
table := testDB.Table(testTableWidgets)

widgets := []any{
widget{
UserID: 1971,
Time: time.Date(1971, 4, 00, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
widget{
UserID: 1971,
Time: time.Date(1971, 4, 10, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
widget{
UserID: 1971,
Time: time.Date(1971, 4, 20, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
}

t.Run("prepare data", func(t *testing.T) {
if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil {
t.Fatal(err)
}
})

iter := testDB.Table(testTableWidgets).Get("UserID", 1971).Iter()
var got []*widget
var count int
for item := range Seq[*widget](ctx, iter) {
t.Log(item)
item.Count = count
got = append(got, item)
count++
}

if iter.Err() != nil {
t.Fatal(iter.Err())
}

t.Run("results match", func(t *testing.T) {
for i, item := range got {
want := widgets[i].(widget)
if !item.Time.Equal(want.Time) {
t.Error("bad result. want:", want.Time, "got:", item.Time)
}
}
})

t.Run("result item isolation", func(t *testing.T) {
// make sure that when mutating the result in the `for ... range` loop
// it only affects one item
t.Log("got", got)
for i, item := range got {
if item.Count != i {
t.Error("unexpected count. got:", item.Count, "want:", i)
}
}
})
}