Skip to content

Commit

Permalink
remove non-context methods
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Jan 27, 2024
1 parent 2b8cde0 commit 9b9e3cb
Show file tree
Hide file tree
Showing 30 changed files with 177 additions and 411 deletions.
28 changes: 16 additions & 12 deletions batch_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dynamo

import (
"context"
"testing"
"time"
)
Expand All @@ -12,6 +13,7 @@ func TestBatchGetWrite(t *testing.T) {
t.Skip(offlineSkipMsg)
}
table := testDB.Table(testTable)
ctx := context.TODO()

items := make([]interface{}, batchSize)
widgets := make(map[int]widget)
Expand All @@ -29,7 +31,7 @@ func TestBatchGetWrite(t *testing.T) {
}

var wcc ConsumedCapacity
wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run()
wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run(ctx)
if wrote != batchSize {
t.Error("unexpected wrote:", wrote, "≠", batchSize)
}
Expand All @@ -48,7 +50,7 @@ func TestBatchGetWrite(t *testing.T) {
Project("UserID", "Time").
Consistent(true).
ConsumedCapacity(&cc).
All(&results)
All(ctx, &results)
if err != nil {
t.Error("unexpected error:", err)
}
Expand All @@ -73,7 +75,7 @@ func TestBatchGetWrite(t *testing.T) {

// delete both
wrote, err = table.Batch("UserID", "Time").Write().
Delete(keys...).Run()
Delete(keys...).Run(ctx)
if wrote != batchSize {
t.Error("unexpected wrote:", wrote, "≠", batchSize)
}
Expand All @@ -86,7 +88,7 @@ func TestBatchGetWrite(t *testing.T) {
err = table.Batch("UserID", "Time").
Get(keys...).
Consistent(true).
All(&results)
All(ctx, &results)
if err != ErrNotFound {
t.Error("expected ErrNotFound, got", err)
}
Expand All @@ -100,15 +102,16 @@ func TestBatchGetEmptySets(t *testing.T) {
t.Skip(offlineSkipMsg)
}
table := testDB.Table(testTable)
ctx := context.TODO()

now := time.Now().UnixNano() / 1000000000
id := int(now)
entry := widget{UserID: id, Time: time.Now()}
if err := table.Put(entry).Run(); err != nil {
if err := table.Put(entry).Run(ctx); err != nil {
panic(err)
}
entry2 := widget{UserID: id + batchSize*2, Time: entry.Time}
if err := table.Put(entry2).Run(); err != nil {
if err := table.Put(entry2).Run(ctx); err != nil {
panic(err)
}

Expand All @@ -118,20 +121,20 @@ func TestBatchGetEmptySets(t *testing.T) {
}

results := []widget{}
err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results)
err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results)
if err != nil {
t.Error(err)
}
if len(results) != 2 {
t.Error("batch get empty set, unexpected length:", len(results), "want:", 2)
}

if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(); err != nil {
if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(ctx); err != nil {
panic(err)
}

results = []widget{}
err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results)
err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results)
if err != nil {
t.Error(err)
}
Expand All @@ -140,7 +143,7 @@ func TestBatchGetEmptySets(t *testing.T) {
}

results = []widget{}
err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(&results)
err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(ctx, &results)
if err != ErrNotFound {
t.Error(err)
}
Expand All @@ -150,14 +153,15 @@ func TestBatchGetEmptySets(t *testing.T) {
}

func TestBatchEmptyInput(t *testing.T) {
ctx := context.TODO()
table := testDB.Table(testTable)
var out []any
err := table.Batch("UserID", "Time").Get().All(&out)
err := table.Batch("UserID", "Time").Get().All(ctx, &out)
if err != ErrNoInput {
t.Error("unexpected error", err)
}

_, err = table.Batch("UserID", "Time").Write().Run()
_, err = table.Batch("UserID", "Time").Write().Run(ctx)
if err != ErrNoInput {
t.Error("unexpected error", err)
}
Expand Down
20 changes: 3 additions & 17 deletions batchget.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,9 @@ func (bg *BatchGet) ConsumedCapacity(cc *ConsumedCapacity) *BatchGet {
}

// All executes this request and unmarshals all results to out, which must be a pointer to a slice.
func (bg *BatchGet) All(out interface{}) error {
func (bg *BatchGet) All(ctx context.Context, out interface{}) error {
iter := newBGIter(bg, unmarshalAppendTo(out), bg.err)
for iter.Next(out) {
}
return iter.Err()
}

// AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice.
func (bg *BatchGet) AllWithContext(ctx context.Context, out interface{}) error {
iter := newBGIter(bg, unmarshalAppendTo(out), bg.err)
for iter.NextWithContext(ctx, out) {
for iter.Next(ctx, out) {
}
return iter.Err()
}
Expand Down Expand Up @@ -216,13 +208,7 @@ func newBGIter(bg *BatchGet, fn unmarshalFunc, err error) *bgIter {

// Next tries to unmarshal the next result into out.
// Returns false when it is complete or if it runs into an error.
func (itr *bgIter) Next(out interface{}) bool {
ctx, cancel := defaultContext()
defer cancel()
return itr.NextWithContext(ctx, out)
}

func (itr *bgIter) NextWithContext(ctx context.Context, out interface{}) bool {
func (itr *bgIter) Next(ctx context.Context, out interface{}) bool {
// stop if we have an error
if ctx.Err() != nil {
itr.err = ctx.Err()
Expand Down
8 changes: 1 addition & 7 deletions batchwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,7 @@ func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite {
// For batches with more than 25 operations, an error could indicate that
// some records have been written and some have not. Consult the wrote
// return amount to figure out which operations have succeeded.
func (bw *BatchWrite) Run() (wrote int, err error) {
ctx, cancel := defaultContext()
defer cancel()
return bw.RunWithContext(ctx)
}

func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) {
func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) {
if bw.err != nil {
return 0, bw.err
}
Expand Down
22 changes: 4 additions & 18 deletions createtable.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,7 @@ func (ct *CreateTable) SSEEncryption(enabled bool, keyID string, sseType SSEType
}

// Run creates this table or returns an error.
func (ct *CreateTable) Run() error {
ctx, cancel := defaultContext()
defer cancel()
return ct.RunWithContext(ctx)
}

// RunWithContext creates this table or returns an error.
func (ct *CreateTable) RunWithContext(ctx context.Context) error {
func (ct *CreateTable) Run(ctx context.Context) error {
if ct.err != nil {
return ct.err
}
Expand All @@ -255,18 +248,11 @@ func (ct *CreateTable) RunWithContext(ctx context.Context) error {
}

// Wait creates this table and blocks until it exists and is ready to use.
func (ct *CreateTable) Wait() error {
ctx, cancel := defaultContext()
defer cancel()
return ct.WaitWithContext(ctx)
}

// WaitWithContext creates this table and blocks until it exists and is ready to use.
func (ct *CreateTable) WaitWithContext(ctx context.Context) error {
if err := ct.RunWithContext(ctx); err != nil {
func (ct *CreateTable) Wait(ctx context.Context) error {
if err := ct.Run(ctx); err != nil {
return err
}
return ct.db.Table(ct.tableName).WaitWithContext(ctx)
return ct.db.Table(ct.tableName).Wait(ctx)
}

func (ct *CreateTable) from(rv reflect.Value) error {
Expand Down
32 changes: 8 additions & 24 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"github.com/guregu/dynamo/dynamodbiface"
"github.com/guregu/dynamo/v2/dynamodbiface"
)

// DB is a DynamoDB client.
Expand Down Expand Up @@ -99,9 +99,9 @@ func (db *DB) Client() dynamodbiface.DynamoDBAPI {
// return db
// }

func (db *DB) log(format string, v ...interface{}) {
db.logger.Logf(logging.Debug, format, v...)
}
// func (db *DB) log(format string, v ...interface{}) {
// db.logger.Logf(logging.Debug, format, v...)
// }

// ListTables is a request to list tables.
// See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_ListTables.html
Expand All @@ -115,18 +115,11 @@ func (db *DB) ListTables() *ListTables {
}

// All returns every table or an error.
func (lt *ListTables) All() ([]string, error) {
ctx, cancel := defaultContext()
defer cancel()
return lt.AllWithContext(ctx)
}

// AllWithContext returns every table or an error.
func (lt *ListTables) AllWithContext(ctx context.Context) ([]string, error) {
func (lt *ListTables) All(ctx context.Context) ([]string, error) {
var tables []string
itr := lt.Iter()
var name string
for itr.NextWithContext(ctx, &name) {
for itr.Next(ctx, &name) {
tables = append(tables, name)
}
return tables, itr.Err()
Expand All @@ -145,13 +138,7 @@ func (lt *ListTables) Iter() Iter {
return &ltIter{lt: lt}
}

func (itr *ltIter) Next(out interface{}) bool {
ctx, cancel := defaultContext()
defer cancel()
return itr.NextWithContext(ctx, out)
}

func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool {
func (itr *ltIter) Next(ctx context.Context, out interface{}) bool {
if ctx.Err() != nil {
itr.err = ctx.Err()
}
Expand Down Expand Up @@ -214,10 +201,7 @@ func (itr *ltIter) input() *dynamodb.ListTablesInput {
type Iter interface {
// Next tries to unmarshal the next result into out.
// Returns false when it is complete or if it runs into an error.
Next(out interface{}) bool
// NextWithContext tries to unmarshal the next result into out.
// Returns false when it is complete or if it runs into an error.
NextWithContext(ctx context.Context, out interface{}) bool
Next(ctx context.Context, out interface{}) bool
// Err returns the error encountered, if any.
// You should check this after Next is finished.
Err() error
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestListTables(t *testing.T) {
t.Skip(offlineSkipMsg)
}

tables, err := testDB.ListTables().All()
tables, err := testDB.ListTables().All(context.TODO())
if err != nil {
t.Error(err)
return
Expand Down
2 changes: 1 addition & 1 deletion decode_aux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/guregu/dynamo"
"github.com/guregu/dynamo/v2"
)

type Coffee struct {
Expand Down
16 changes: 2 additions & 14 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,15 @@ func (d *Delete) ConsumedCapacity(cc *ConsumedCapacity) *Delete {
}

// Run executes this delete request.
func (d *Delete) Run() error {
ctx, cancel := defaultContext()
defer cancel()
return d.RunWithContext(ctx)
}

func (d *Delete) RunWithContext(ctx context.Context) error {
func (d *Delete) Run(ctx context.Context) error {
d.returnType = "NONE"
_, err := d.run(ctx)
return err
}

// OldValue executes this delete request, unmarshaling the previous value to out.
// Returns ErrNotFound is there was no previous value.
func (d *Delete) OldValue(out interface{}) error {
ctx, cancel := defaultContext()
defer cancel()
return d.OldValueWithContext(ctx, out)
}

func (d *Delete) OldValueWithContext(ctx context.Context, out interface{}) error {
func (d *Delete) OldValue(ctx context.Context, out interface{}) error {
d.returnType = "ALL_OLD"
output, err := d.run(ctx)
switch {
Expand Down
8 changes: 5 additions & 3 deletions delete_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dynamo

import (
"context"
"reflect"
"testing"
"time"
Expand All @@ -11,6 +12,7 @@ func TestDelete(t *testing.T) {
t.Skip(offlineSkipMsg)
}
table := testDB.Table(testTable)
ctx := context.TODO()

// first, add an item to delete later
item := widget{
Expand All @@ -21,7 +23,7 @@ func TestDelete(t *testing.T) {
"color": "octarine",
},
}
err := table.Put(item).Run()
err := table.Put(item).Run(ctx)
if err != nil {
t.Error("unexpected error:", err)
}
Expand All @@ -31,15 +33,15 @@ func TestDelete(t *testing.T) {
Range("Time", item.Time).
If("Meta.'color' = ?", "octarine").
If("Msg = ?", "wrong msg").
Run()
Run(ctx)
if !IsCondCheckFailed(err) {
t.Error("expected ConditionalCheckFailedException, not", err)
}

// delete it
var old widget
var cc ConsumedCapacity
err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(&old)
err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(ctx, &old)
if err != nil {
t.Error("unexpected error:", err)
}
Expand Down
8 changes: 1 addition & 7 deletions describetable.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,7 @@ func (table Table) Describe() *DescribeTable {
}

// Run executes this request and describe the table.
func (dt *DescribeTable) Run() (Description, error) {
ctx, cancel := defaultContext()
defer cancel()
return dt.RunWithContext(ctx)
}

func (dt *DescribeTable) RunWithContext(ctx context.Context) (Description, error) {
func (dt *DescribeTable) Run(ctx context.Context) (Description, error) {
input := dt.input()

var result *dynamodb.DescribeTableOutput
Expand Down
Loading

0 comments on commit 9b9e3cb

Please sign in to comment.