From 9b9e3cb1eaec443ba52fbc6e0ae9cc48bc9589b2 Mon Sep 17 00:00:00 2001 From: Greg Date: Sat, 27 Jan 2024 19:33:04 +0900 Subject: [PATCH] remove non-context methods --- batch_test.go | 28 ++++++++++++--------- batchget.go | 20 +++------------ batchwrite.go | 8 +----- createtable.go | 22 +++------------- db.go | 32 ++++++------------------ db_test.go | 2 +- decode_aux_test.go | 2 +- delete.go | 16 ++---------- delete_test.go | 8 +++--- describetable.go | 8 +----- describetable_test.go | 3 ++- go.mod | 2 +- put.go | 19 ++------------ put_test.go | 15 ++++++----- query.go | 46 ++++++---------------------------- query_test.go | 38 +++++++++++++++------------- retry.go | 19 ++------------ scan.go | 58 ++++++++----------------------------------- scan_test.go | 23 +++++++++-------- substitute.go | 2 +- table.go | 36 ++++++--------------------- table_test.go | 11 +++++--- ttl.go | 18 ++------------ ttl_test.go | 4 ++- tx.go | 27 +++----------------- tx_test.go | 38 +++++++++++++++------------- update.go | 49 ++++-------------------------------- update_test.go | 22 +++++++++------- updatetable.go | 8 +----- updatetable_test.go | 4 ++- 30 files changed, 177 insertions(+), 411 deletions(-) diff --git a/batch_test.go b/batch_test.go index 398e455..c49cf4a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" "time" ) @@ -12,6 +13,7 @@ func TestBatchGetWrite(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() items := make([]interface{}, batchSize) widgets := make(map[int]widget) @@ -29,7 +31,7 @@ func TestBatchGetWrite(t *testing.T) { } var wcc ConsumedCapacity - wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run() + wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run(ctx) if wrote != batchSize { t.Error("unexpected wrote:", wrote, "≠", batchSize) } @@ -48,7 +50,7 @@ func TestBatchGetWrite(t *testing.T) { Project("UserID", "Time"). Consistent(true). ConsumedCapacity(&cc). - All(&results) + All(ctx, &results) if err != nil { t.Error("unexpected error:", err) } @@ -73,7 +75,7 @@ func TestBatchGetWrite(t *testing.T) { // delete both wrote, err = table.Batch("UserID", "Time").Write(). - Delete(keys...).Run() + Delete(keys...).Run(ctx) if wrote != batchSize { t.Error("unexpected wrote:", wrote, "≠", batchSize) } @@ -86,7 +88,7 @@ func TestBatchGetWrite(t *testing.T) { err = table.Batch("UserID", "Time"). Get(keys...). Consistent(true). - All(&results) + All(ctx, &results) if err != ErrNotFound { t.Error("expected ErrNotFound, got", err) } @@ -100,15 +102,16 @@ func TestBatchGetEmptySets(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() now := time.Now().UnixNano() / 1000000000 id := int(now) entry := widget{UserID: id, Time: time.Now()} - if err := table.Put(entry).Run(); err != nil { + if err := table.Put(entry).Run(ctx); err != nil { panic(err) } entry2 := widget{UserID: id + batchSize*2, Time: entry.Time} - if err := table.Put(entry2).Run(); err != nil { + if err := table.Put(entry2).Run(ctx); err != nil { panic(err) } @@ -118,7 +121,7 @@ func TestBatchGetEmptySets(t *testing.T) { } results := []widget{} - err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results) + err := table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results) if err != nil { t.Error(err) } @@ -126,12 +129,12 @@ func TestBatchGetEmptySets(t *testing.T) { t.Error("batch get empty set, unexpected length:", len(results), "want:", 2) } - if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(); err != nil { + if err := table.Delete("UserID", entry.UserID).Range("Time", entry.Time).Run(ctx); err != nil { panic(err) } results = []widget{} - err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(&results) + err = table.Batch("UserID", "Time").Get(keysToCheck...).Consistent(true).All(ctx, &results) if err != nil { t.Error(err) } @@ -140,7 +143,7 @@ func TestBatchGetEmptySets(t *testing.T) { } results = []widget{} - err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(&results) + err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(ctx, &results) if err != ErrNotFound { t.Error(err) } @@ -150,14 +153,15 @@ func TestBatchGetEmptySets(t *testing.T) { } func TestBatchEmptyInput(t *testing.T) { + ctx := context.TODO() table := testDB.Table(testTable) var out []any - err := table.Batch("UserID", "Time").Get().All(&out) + err := table.Batch("UserID", "Time").Get().All(ctx, &out) if err != ErrNoInput { t.Error("unexpected error", err) } - _, err = table.Batch("UserID", "Time").Write().Run() + _, err = table.Batch("UserID", "Time").Write().Run(ctx) if err != ErrNoInput { t.Error("unexpected error", err) } diff --git a/batchget.go b/batchget.go index 8cd47a2..e5b8b03 100644 --- a/batchget.go +++ b/batchget.go @@ -118,17 +118,9 @@ func (bg *BatchGet) ConsumedCapacity(cc *ConsumedCapacity) *BatchGet { } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (bg *BatchGet) All(out interface{}) error { +func (bg *BatchGet) All(ctx context.Context, out interface{}) error { iter := newBGIter(bg, unmarshalAppendTo(out), bg.err) - for iter.Next(out) { - } - return iter.Err() -} - -// AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (bg *BatchGet) AllWithContext(ctx context.Context, out interface{}) error { - iter := newBGIter(bg, unmarshalAppendTo(out), bg.err) - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } return iter.Err() } @@ -216,13 +208,7 @@ func newBGIter(bg *BatchGet, fn unmarshalFunc, err error) *bgIter { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *bgIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *bgIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *bgIter) Next(ctx context.Context, out interface{}) bool { // stop if we have an error if ctx.Err() != nil { itr.err = ctx.Err() diff --git a/batchwrite.go b/batchwrite.go index 642b88b..93e84ba 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -67,13 +67,7 @@ func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite { // For batches with more than 25 operations, an error could indicate that // some records have been written and some have not. Consult the wrote // return amount to figure out which operations have succeeded. -func (bw *BatchWrite) Run() (wrote int, err error) { - ctx, cancel := defaultContext() - defer cancel() - return bw.RunWithContext(ctx) -} - -func (bw *BatchWrite) RunWithContext(ctx context.Context) (wrote int, err error) { +func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { if bw.err != nil { return 0, bw.err } diff --git a/createtable.go b/createtable.go index caec3d3..168798d 100644 --- a/createtable.go +++ b/createtable.go @@ -235,14 +235,7 @@ func (ct *CreateTable) SSEEncryption(enabled bool, keyID string, sseType SSEType } // Run creates this table or returns an error. -func (ct *CreateTable) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return ct.RunWithContext(ctx) -} - -// RunWithContext creates this table or returns an error. -func (ct *CreateTable) RunWithContext(ctx context.Context) error { +func (ct *CreateTable) Run(ctx context.Context) error { if ct.err != nil { return ct.err } @@ -255,18 +248,11 @@ func (ct *CreateTable) RunWithContext(ctx context.Context) error { } // Wait creates this table and blocks until it exists and is ready to use. -func (ct *CreateTable) Wait() error { - ctx, cancel := defaultContext() - defer cancel() - return ct.WaitWithContext(ctx) -} - -// WaitWithContext creates this table and blocks until it exists and is ready to use. -func (ct *CreateTable) WaitWithContext(ctx context.Context) error { - if err := ct.RunWithContext(ctx); err != nil { +func (ct *CreateTable) Wait(ctx context.Context) error { + if err := ct.Run(ctx); err != nil { return err } - return ct.db.Table(ct.tableName).WaitWithContext(ctx) + return ct.db.Table(ct.tableName).Wait(ctx) } func (ct *CreateTable) from(rv reflect.Value) error { diff --git a/db.go b/db.go index 51d80c7..22dc5a6 100644 --- a/db.go +++ b/db.go @@ -12,7 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" - "github.com/guregu/dynamo/dynamodbiface" + "github.com/guregu/dynamo/v2/dynamodbiface" ) // DB is a DynamoDB client. @@ -99,9 +99,9 @@ func (db *DB) Client() dynamodbiface.DynamoDBAPI { // return db // } -func (db *DB) log(format string, v ...interface{}) { - db.logger.Logf(logging.Debug, format, v...) -} +// func (db *DB) log(format string, v ...interface{}) { +// db.logger.Logf(logging.Debug, format, v...) +// } // ListTables is a request to list tables. // See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_ListTables.html @@ -115,18 +115,11 @@ func (db *DB) ListTables() *ListTables { } // All returns every table or an error. -func (lt *ListTables) All() ([]string, error) { - ctx, cancel := defaultContext() - defer cancel() - return lt.AllWithContext(ctx) -} - -// AllWithContext returns every table or an error. -func (lt *ListTables) AllWithContext(ctx context.Context) ([]string, error) { +func (lt *ListTables) All(ctx context.Context) ([]string, error) { var tables []string itr := lt.Iter() var name string - for itr.NextWithContext(ctx, &name) { + for itr.Next(ctx, &name) { tables = append(tables, name) } return tables, itr.Err() @@ -145,13 +138,7 @@ func (lt *ListTables) Iter() Iter { return <Iter{lt: lt} } -func (itr *ltIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *ltIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *ltIter) Next(ctx context.Context, out interface{}) bool { if ctx.Err() != nil { itr.err = ctx.Err() } @@ -214,10 +201,7 @@ func (itr *ltIter) input() *dynamodb.ListTablesInput { type Iter interface { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. - Next(out interface{}) bool - // NextWithContext tries to unmarshal the next result into out. - // Returns false when it is complete or if it runs into an error. - NextWithContext(ctx context.Context, out interface{}) bool + Next(ctx context.Context, out interface{}) bool // Err returns the error encountered, if any. // You should check this after Next is finished. Err() error diff --git a/db_test.go b/db_test.go index 0b757b8..a83f1f1 100644 --- a/db_test.go +++ b/db_test.go @@ -63,7 +63,7 @@ func TestListTables(t *testing.T) { t.Skip(offlineSkipMsg) } - tables, err := testDB.ListTables().All() + tables, err := testDB.ListTables().All(context.TODO()) if err != nil { t.Error(err) return diff --git a/decode_aux_test.go b/decode_aux_test.go index 5e6105d..5e7c570 100644 --- a/decode_aux_test.go +++ b/decode_aux_test.go @@ -7,7 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "github.com/guregu/dynamo" + "github.com/guregu/dynamo/v2" ) type Coffee struct { diff --git a/delete.go b/delete.go index 90cc996..85ca119 100644 --- a/delete.go +++ b/delete.go @@ -78,13 +78,7 @@ func (d *Delete) ConsumedCapacity(cc *ConsumedCapacity) *Delete { } // Run executes this delete request. -func (d *Delete) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return d.RunWithContext(ctx) -} - -func (d *Delete) RunWithContext(ctx context.Context) error { +func (d *Delete) Run(ctx context.Context) error { d.returnType = "NONE" _, err := d.run(ctx) return err @@ -92,13 +86,7 @@ func (d *Delete) RunWithContext(ctx context.Context) error { // OldValue executes this delete request, unmarshaling the previous value to out. // Returns ErrNotFound is there was no previous value. -func (d *Delete) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return d.OldValueWithContext(ctx, out) -} - -func (d *Delete) OldValueWithContext(ctx context.Context, out interface{}) error { +func (d *Delete) OldValue(ctx context.Context, out interface{}) error { d.returnType = "ALL_OLD" output, err := d.run(ctx) switch { diff --git a/delete_test.go b/delete_test.go index 751565d..6858678 100644 --- a/delete_test.go +++ b/delete_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "reflect" "testing" "time" @@ -11,6 +12,7 @@ func TestDelete(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() // first, add an item to delete later item := widget{ @@ -21,7 +23,7 @@ func TestDelete(t *testing.T) { "color": "octarine", }, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -31,7 +33,7 @@ func TestDelete(t *testing.T) { Range("Time", item.Time). If("Meta.'color' = ?", "octarine"). If("Msg = ?", "wrong msg"). - Run() + Run(ctx) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -39,7 +41,7 @@ func TestDelete(t *testing.T) { // delete it var old widget var cc ConsumedCapacity - err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(&old) + err = table.Delete("UserID", item.UserID).Range("Time", item.Time).ConsumedCapacity(&cc).OldValue(ctx, &old) if err != nil { t.Error("unexpected error:", err) } diff --git a/describetable.go b/describetable.go index e37cf4c..4c295ca 100644 --- a/describetable.go +++ b/describetable.go @@ -254,13 +254,7 @@ func (table Table) Describe() *DescribeTable { } // Run executes this request and describe the table. -func (dt *DescribeTable) Run() (Description, error) { - ctx, cancel := defaultContext() - defer cancel() - return dt.RunWithContext(ctx) -} - -func (dt *DescribeTable) RunWithContext(ctx context.Context) (Description, error) { +func (dt *DescribeTable) Run(ctx context.Context) (Description, error) { input := dt.input() var result *dynamodb.DescribeTableOutput diff --git a/describetable_test.go b/describetable_test.go index 34bc50e..894b951 100644 --- a/describetable_test.go +++ b/describetable_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -10,7 +11,7 @@ func TestDescribeTable(t *testing.T) { } table := testDB.Table(testTable) - desc, err := table.Describe().Run() + desc, err := table.Describe().Run(context.TODO()) if err != nil { t.Error(err) return diff --git a/go.mod b/go.mod index 9f44979..35af853 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/guregu/dynamo +module github.com/guregu/dynamo/v2 require ( github.com/aws/aws-sdk-go-v2 v1.24.1 diff --git a/put.go b/put.go index f3ca6f3..4d43b71 100644 --- a/put.go +++ b/put.go @@ -53,14 +53,7 @@ func (p *Put) ConsumedCapacity(cc *ConsumedCapacity) *Put { } // Run executes this put. -func (p *Put) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return p.RunWithContext(ctx) -} - -// Run executes this put. -func (p *Put) RunWithContext(ctx context.Context) error { +func (p *Put) Run(ctx context.Context) error { p.returnType = "NONE" _, err := p.run(ctx) return err @@ -68,15 +61,7 @@ func (p *Put) RunWithContext(ctx context.Context) error { // OldValue executes this put, unmarshaling the previous value into out. // Returns ErrNotFound is there was no previous value. -func (p *Put) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return p.OldValueWithContext(ctx, out) -} - -// OldValueWithContext executes this put, unmarshaling the previous value into out. -// Returns ErrNotFound is there was no previous value. -func (p *Put) OldValueWithContext(ctx context.Context, out interface{}) error { +func (p *Put) OldValue(ctx context.Context, out interface{}) error { p.returnType = "ALL_OLD" output, err := p.run(ctx) switch { diff --git a/put_test.go b/put_test.go index 50acd52..3c37edc 100644 --- a/put_test.go +++ b/put_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "reflect" "testing" "time" @@ -13,6 +14,7 @@ func TestPut(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() type widget2 struct { widget @@ -34,7 +36,7 @@ func TestPut(t *testing.T) { List: []*string{}, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -53,7 +55,7 @@ func TestPut(t *testing.T) { } var oldValue widget2 var cc ConsumedCapacity - err = table.Put(newItem).ConsumedCapacity(&cc).OldValue(&oldValue) + err = table.Put(newItem).ConsumedCapacity(&cc).OldValue(ctx, &oldValue) if err != nil { t.Error("unexpected error:", err) } @@ -67,7 +69,7 @@ func TestPut(t *testing.T) { } // putting the same item: this should fail - err = table.Put(newItem).If("attribute_not_exists(UserID)").If("attribute_not_exists('Time')").Run() + err = table.Put(newItem).If("attribute_not_exists(UserID)").If("attribute_not_exists('Time')").Run(ctx) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -78,6 +80,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() type awsWidget struct { XUserID int `dynamodbav:"UserID"` @@ -98,13 +101,13 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { XMsg: "hello world", } - err = table.Put(AWSEncoding(item)).Run() + err = table.Put(AWSEncoding(item)).Run(ctx) if err != nil { t.Error(err) } var result awsWidget - err = table.Get("UserID", item.XUserID).Range("Time", Equal, item.XTime).Consistent(true).One(AWSEncoding(&result)) + err = table.Get("UserID", item.XUserID).Range("Time", Equal, item.XTime).Consistent(true).One(ctx, AWSEncoding(&result)) if err != nil { t.Error(err) } @@ -113,7 +116,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { } var list []awsWidget - err = table.Get("UserID", item.XUserID).Consistent(true).All(AWSEncoding(&list)) + err = table.Get("UserID", item.XUserID).Consistent(true).All(ctx, AWSEncoding(&list)) if err != nil { t.Error(err) } diff --git a/query.go b/query.go index 14e2411..b2eaccc 100644 --- a/query.go +++ b/query.go @@ -197,13 +197,7 @@ func (q *Query) ConsumedCapacity(cc *ConsumedCapacity) *Query { // One executes this query and retrieves a single result, // unmarshaling the result to out. -func (q *Query) One(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return q.OneWithContext(ctx, out) -} - -func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { +func (q *Query) One(ctx context.Context, out interface{}) error { if q.err != nil { return q.err } @@ -267,13 +261,7 @@ func (q *Query) OneWithContext(ctx context.Context, out interface{}) error { } // Count executes this request, returning the number of results. -func (q *Query) Count() (int, error) { - ctx, cancel := defaultContext() - defer cancel() - return q.CountWithContext(ctx) -} - -func (q *Query) CountWithContext(ctx context.Context) (int, error) { +func (q *Query) Count(ctx context.Context) (int, error) { if q.err != nil { return 0, q.err } @@ -335,13 +323,7 @@ type queryIter struct { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *queryIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *queryIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *queryIter) Next(ctx context.Context, out interface{}) bool { // stop if we have an error if ctx.Err() != nil { itr.err = ctx.Err() @@ -404,7 +386,7 @@ func (itr *queryIter) NextWithContext(ctx context.Context, out interface{}) bool if len(itr.output.Items) == 0 { if itr.output.LastEvaluatedKey != nil { // we need to retry until we get some data - return itr.NextWithContext(ctx, out) + return itr.Next(ctx, out) } // we're done return false @@ -453,38 +435,26 @@ func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (q *Query) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return q.AllWithContext(ctx, out) -} - -func (q *Query) AllWithContext(ctx context.Context, out interface{}) error { +func (q *Query) All(ctx context.Context, out interface{}) error { iter := &queryIter{ query: q, unmarshal: unmarshalAppendTo(out), err: q.err, } - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } return iter.Err() } // AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice. // This returns a PagingKey you can use with StartFrom to split up results. -func (q *Query) AllWithLastEvaluatedKey(out interface{}) (PagingKey, error) { - ctx, cancel := defaultContext() - defer cancel() - return q.AllWithLastEvaluatedKeyContext(ctx, out) -} - -func (q *Query) AllWithLastEvaluatedKeyContext(ctx context.Context, out interface{}) (PagingKey, error) { +func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) { iter := &queryIter{ query: q, unmarshal: unmarshalAppendTo(out), err: q.err, } - for iter.NextWithContext(ctx, out) { + for iter.Next(ctx, out) { } lek, err := iter.LastEvaluatedKey(ctx) return lek, errors.Join(iter.Err(), err) diff --git a/query_test.go b/query_test.go index 2ffed0e..a3d714f 100644 --- a/query_test.go +++ b/query_test.go @@ -14,6 +14,7 @@ func TestGetAllCount(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() table := testDB.Table(testTable) // first, add an item to make sure there is at least one @@ -27,7 +28,7 @@ func TestGetAllCount(t *testing.T) { }, StrPtr: new(string), } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -52,7 +53,7 @@ func TestGetAllCount(t *testing.T) { Filter("StrPtr = ?", ""). Filter("?", lit). ConsumedCapacity(&cc1). - All(&result) + All(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -63,7 +64,7 @@ func TestGetAllCount(t *testing.T) { Filter("StrPtr = ?", ""). Filter("$", lit). // both $ and ? are OK for literals ConsumedCapacity(&cc2). - Count() + Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -92,7 +93,7 @@ func TestGetAllCount(t *testing.T) { // query specifically against the inserted item (using GetItem) var one widget - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -102,7 +103,7 @@ func TestGetAllCount(t *testing.T) { // query specifically against the inserted item (using Query) one = widget{} - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Filter("Msg = ?", item.Msg).Filter("StrPtr = ?", "").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Filter("Msg = ?", item.Msg).Filter("StrPtr = ?", "").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -116,7 +117,7 @@ func TestGetAllCount(t *testing.T) { UserID: item.UserID, Time: item.Time, } - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Project("UserID", "Time").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).Project("UserID", "Time").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -134,7 +135,7 @@ func TestGetAllCount(t *testing.T) { "animal.cow": "moo", }, } - err = table.Get("UserID", 42).Range("Time", Equal, item.Time).ProjectExpr("UserID, $, Meta.foo, Meta.$", "Time", "animal.cow").Consistent(true).One(&one) + err = table.Get("UserID", 42).Range("Time", Equal, item.Time).ProjectExpr("UserID, $, Meta.foo, Meta.$", "Time", "animal.cow").Consistent(true).One(ctx, &one) if err != nil { t.Error("unexpected error:", err) } @@ -147,6 +148,7 @@ func TestQueryPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() table := testDB.Table(testTable) widgets := []interface{}{ @@ -167,7 +169,7 @@ func TestQueryPaging(t *testing.T) { }, } - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Error("couldn't write paging prep data", err) return } @@ -175,14 +177,14 @@ func TestQueryPaging(t *testing.T) { itr := table.Get("UserID", 1969).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } @@ -198,6 +200,7 @@ func TestQueryMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.Background() table := testDB.Table(testTable) widgets := []interface{}{ @@ -219,7 +222,7 @@ func TestQueryMagicLEK(t *testing.T) { } t.Run("prepare data", func(t *testing.T) { - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Fatal(err) } }) @@ -228,14 +231,14 @@ func TestQueryMagicLEK(t *testing.T) { itr := table.Get("UserID", 1970).Filter("attribute_exists('Msg')").Limit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } @@ -265,14 +268,14 @@ func TestQueryMagicLEK(t *testing.T) { itr := table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).Limit(1).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if !reflect.DeepEqual(w, widgets[i]) { t.Error("bad result:", w, "≠", widgets[i]) } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } - more := itr.Next(&w) + more := itr.Next(ctx, &w) if more { t.Error("unexpected more", more) } @@ -290,10 +293,11 @@ func TestQueryBadKeys(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.Background() t.Run("hash key", func(t *testing.T) { var v interface{} - err := table.Get("UserID", "").Range("Time", Equal, "123").One(&v) + err := table.Get("UserID", "").Range("Time", Equal, "123").One(ctx, &v) if err == nil { t.Error("want error, got", err) } @@ -301,7 +305,7 @@ func TestQueryBadKeys(t *testing.T) { t.Run("range key", func(t *testing.T) { var v interface{} - err := table.Get("UserID", 123).Range("Time", Equal, "").One(&v) + err := table.Get("UserID", 123).Range("Time", Equal, "").One(ctx, &v) if err == nil { t.Error("want error, got", err) } diff --git a/retry.go b/retry.go index a8d7d4f..43b8d6e 100644 --- a/retry.go +++ b/retry.go @@ -8,13 +8,10 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/smithy-go" + awstime "github.com/aws/smithy-go/time" "github.com/cenkalti/backoff/v4" ) -func defaultContext() (context.Context, context.CancelFunc) { - return context.Background(), func() {} -} - func (db *DB) retry(ctx context.Context, f func() error) error { // if a custom retryer has been set, the SDK will retry for us if db.retryer != nil { @@ -36,7 +33,7 @@ func (db *DB) retry(ctx context.Context, f func() error) error { if next = b.NextBackOff(); next == backoff.Stop { return err } - if err := sleep(ctx, next); err != nil { + if err := awstime.SleepWithContext(ctx, next); err != nil { return err } } @@ -87,15 +84,3 @@ func canRetry(err error) bool { return false } - -func sleep(ctx context.Context, dur time.Duration) error { - timer := time.NewTimer(dur) - defer timer.Stop() - - select { - case <-ctx.Done(): - case <-timer.C: - } - - return ctx.Err() -} diff --git a/scan.go b/scan.go index 3dafa3b..ed4d940 100644 --- a/scan.go +++ b/scan.go @@ -164,41 +164,26 @@ func (s *Scan) IterParallelStartFrom(ctx context.Context, keys []PagingKey) Para } // All executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (s *Scan) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return s.AllWithContext(ctx, out) -} - -// AllWithContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -func (s *Scan) AllWithContext(ctx context.Context, out interface{}) error { +func (s *Scan) All(ctx context.Context, out interface{}) error { itr := &scanIter{ scan: s, unmarshal: unmarshalAppendTo(out), err: s.err, } - for itr.NextWithContext(ctx, out) { + for itr.Next(ctx, out) { } return itr.Err() } // AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice. // It returns a key you can use with StartWith to continue this query. -func (s *Scan) AllWithLastEvaluatedKey(out interface{}) (PagingKey, error) { - ctx, cancel := defaultContext() - defer cancel() - return s.AllWithLastEvaluatedKeyContext(ctx, out) -} - -// AllWithLastEvaluatedKeyContext executes this request and unmarshals all results to out, which must be a pointer to a slice. -// It returns a key you can use with StartWith to continue this query. -func (s *Scan) AllWithLastEvaluatedKeyContext(ctx context.Context, out interface{}) (PagingKey, error) { +func (s *Scan) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) { itr := &scanIter{ scan: s, unmarshal: unmarshalAppendTo(out), err: s.err, } - for itr.NextWithContext(ctx, out) { + for itr.Next(ctx, out) { } lek, err := itr.LastEvaluatedKey(ctx) return lek, errors.Join(itr.Err(), err) @@ -209,7 +194,7 @@ func (s *Scan) AllParallel(ctx context.Context, segments int, out interface{}) e iters := s.newSegments(segments, nil) ps := newParallelScan(iters, s.cc, true, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } return ps.Err() } @@ -220,7 +205,7 @@ func (s *Scan) AllParallelWithLastEvaluatedKeys(ctx context.Context, segments in iters := s.newSegments(segments, nil) ps := newParallelScan(iters, s.cc, false, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } leks, err := ps.LastEvaluatedKeys(ctx) return leks, errors.Join(ps.Err(), err) @@ -232,7 +217,7 @@ func (s *Scan) AllParallelStartFrom(ctx context.Context, keys []PagingKey, out i iters := s.newSegments(len(keys), keys) ps := newParallelScan(iters, s.cc, false, unmarshalAppendTo(out)) go ps.run(ctx) - for ps.NextWithContext(ctx, out) { + for ps.Next(ctx, out) { } leks, err := ps.LastEvaluatedKeys(ctx) return leks, errors.Join(ps.Err(), err) @@ -241,16 +226,7 @@ func (s *Scan) AllParallelStartFrom(ctx context.Context, keys []PagingKey, out i // Count executes this request and returns the number of items matching the scan. // It takes into account the filter, limit, search limit, and all other parameters given. // It may return a higher count than the limits. -func (s *Scan) Count() (int, error) { - ctx, cancel := defaultContext() - defer cancel() - return s.CountWithContext(ctx) -} - -// CountWithContext executes this request and returns the number of items matching the scan. -// It takes into account the filter, limit, search limit, and all other parameters given. -// It may return a higher count than the limits. -func (s *Scan) CountWithContext(ctx context.Context) (int, error) { +func (s *Scan) Count(ctx context.Context) (int, error) { if s.err != nil { return 0, s.err } @@ -357,13 +333,7 @@ type scanIter struct { // Next tries to unmarshal the next result into out. // Returns false when it is complete or if it runs into an error. -func (itr *scanIter) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return itr.NextWithContext(ctx, out) -} - -func (itr *scanIter) NextWithContext(ctx context.Context, out interface{}) bool { +func (itr *scanIter) Next(ctx context.Context, out interface{}) bool { redo: // stop if we have an error if ctx.Err() != nil { @@ -512,7 +482,7 @@ func (ps *parallelScan) run(ctx context.Context) { } grp.Go(func() error { var item Item - for iter.NextWithContext(ctx, &item) { + for iter.Next(ctx, &item) { select { case <-ctx.Done(): return ctx.Err() @@ -548,13 +518,7 @@ func (ps *parallelScan) run(ctx context.Context) { close(ps.items) } -func (ps *parallelScan) Next(out interface{}) bool { - ctx, cancel := defaultContext() - defer cancel() - return ps.NextWithContext(ctx, out) -} - -func (ps *parallelScan) NextWithContext(ctx context.Context, out interface{}) bool { +func (ps *parallelScan) Next(ctx context.Context, out interface{}) bool { select { case <-ctx.Done(): ps.setError(ctx.Err()) diff --git a/scan_test.go b/scan_test.go index 72ecc48..b4d0ef3 100644 --- a/scan_test.go +++ b/scan_test.go @@ -12,6 +12,7 @@ func TestScan(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() // first, add an item to make sure there is at least one item := widget{ @@ -19,13 +20,13 @@ func TestScan(t *testing.T) { Time: time.Now().UTC(), Msg: "hello", } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } // count items via Query - ct, err := table.Get("UserID", 42).Consistent(true).Count() + ct, err := table.Get("UserID", 42).Consistent(true).Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -34,7 +35,7 @@ func TestScan(t *testing.T) { t.Run("All", func(t *testing.T) { var result []widget var cc ConsumedCapacity - err = table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc).All(&result) + err = table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc).All(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -61,7 +62,7 @@ func TestScan(t *testing.T) { // check this against Scan's count, too t.Run("Count", func(t *testing.T) { var cc2 ConsumedCapacity - scanCt, err := table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc2).Count() + scanCt, err := table.Scan().Filter("UserID = ?", 42).Consistent(true).ConsumedCapacity(&cc2).Count(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -106,6 +107,7 @@ func TestScanPaging(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() // prepare data insert := make([]interface{}, 10) @@ -116,7 +118,7 @@ func TestScanPaging(t *testing.T) { Msg: "garbage", } } - if _, err := table.Batch().Write().Put(insert...).Run(); err != nil { + if _, err := table.Batch().Write().Put(insert...).Run(ctx); err != nil { t.Fatal(err) } @@ -124,7 +126,7 @@ func TestScanPaging(t *testing.T) { widgets := [10]widget{} itr := table.Scan().Consistent(true).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { - more := itr.Next(&widgets[i]) + more := itr.Next(ctx, &widgets[i]) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } @@ -152,7 +154,7 @@ func TestScanPaging(t *testing.T) { for i := 0; i < len(widgets)/segments; i++ { var more bool for j := 0; j < segments; j++ { - more = itr.Next(&widgets[i*segments+j]) + more = itr.Next(ctx, &widgets[i*segments+j]) if !more && j != segments-1 { t.Error("bad number of results from parallel scan") } @@ -182,6 +184,7 @@ func TestScanMagicLEK(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.Background() widgets := []interface{}{ widget{ @@ -201,7 +204,7 @@ func TestScanMagicLEK(t *testing.T) { }, } // prepare data - if _, err := table.Batch().Write().Put(widgets...).Run(); err != nil { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { t.Fatal(err) } @@ -209,7 +212,7 @@ func TestScanMagicLEK(t *testing.T) { itr := table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").Limit(2).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } @@ -225,7 +228,7 @@ func TestScanMagicLEK(t *testing.T) { itr := table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).Limit(2).Iter() for i := 0; i < len(widgets); i++ { var w widget - itr.Next(&w) + itr.Next(ctx, &w) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) } diff --git a/substitute.go b/substitute.go index b81e78e..65c2631 100644 --- a/substitute.go +++ b/substitute.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/guregu/dynamo/internal/exprs" + "github.com/guregu/dynamo/v2/internal/exprs" ) // subber is a "mixin" for operators for keep track of subtituted keys and values diff --git a/table.go b/table.go index 559e4bb..f3235f5 100644 --- a/table.go +++ b/table.go @@ -54,15 +54,7 @@ func (table Table) Name() string { // Wait blocks until this table's status matches any status provided by want. // If no statuses are specified, the active status is used. -func (table Table) Wait(want ...Status) error { - ctx, cancel := defaultContext() - defer cancel() - return table.WaitWithContext(ctx, want...) -} - -// Wait blocks until this table's status matches any status provided by want. -// If no statuses are specified, the active status is used. -func (table Table) WaitWithContext(ctx context.Context, want ...Status) error { +func (table Table) Wait(ctx context.Context, want ...Status) error { if len(want) == 0 { want = []Status{ActiveStatus} } @@ -74,7 +66,7 @@ func (table Table) WaitWithContext(ctx context.Context, want ...Status) error { } err := table.db.retry(ctx, func() error { - desc, err := table.Describe().RunWithContext(ctx) + desc, err := table.Describe().Run(ctx) var aerr smithy.APIError if errors.As(err, &aerr) { if aerr.ErrorCode() == "ResourceNotFoundException" { @@ -134,7 +126,7 @@ func (table Table) primaryKeys(ctx context.Context, lek, esk Item, index string) keys := make(map[string]struct{}) err := table.db.retry(ctx, func() error { - desc, err := table.Describe().RunWithContext(ctx) + desc, err := table.Describe().Run(ctx) if err != nil { return err } @@ -181,14 +173,7 @@ func (table Table) DeleteTable() *DeleteTable { } // Run executes this request and deletes the table. -func (dt *DeleteTable) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return dt.RunWithContext(ctx) -} - -// RunWithContext executes this request and deletes the table. -func (dt *DeleteTable) RunWithContext(ctx context.Context) error { +func (dt *DeleteTable) Run(ctx context.Context) error { input := dt.input() return dt.table.db.retry(ctx, func() error { _, err := dt.table.db.client.DeleteTable(ctx, input) @@ -197,18 +182,11 @@ func (dt *DeleteTable) RunWithContext(ctx context.Context) error { } // Wait executes this request and blocks until the table is finished deleting. -func (dt *DeleteTable) Wait() error { - ctx, cancel := defaultContext() - defer cancel() - return dt.WaitWithContext(ctx) -} - -// WaitWithContext executes this request and blocks until the table is finished deleting. -func (dt *DeleteTable) WaitWithContext(ctx context.Context) error { - if err := dt.RunWithContext(ctx); err != nil { +func (dt *DeleteTable) Wait(ctx context.Context) error { + if err := dt.Run(ctx); err != nil { return err } - return dt.table.WaitWithContext(ctx, NotExistsStatus) + return dt.table.Wait(ctx, NotExistsStatus) } func (dt *DeleteTable) input() *dynamodb.DeleteTableInput { diff --git a/table_test.go b/table_test.go index 30d6195..ca356cc 100644 --- a/table_test.go +++ b/table_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "fmt" "reflect" "sort" @@ -19,6 +20,8 @@ func TestTableLifecycle(t *testing.T) { t.SkipNow() } + ctx := context.TODO() + now := time.Now().UTC() name := fmt.Sprintf("TestDB-%d", now.UnixNano()) @@ -37,11 +40,11 @@ func TestTableLifecycle(t *testing.T) { HashKeyType: StringType, RangeKey: "Bar", RangeKeyType: NumberType, - }).Wait(); err != nil { + }).Wait(ctx); err != nil { t.Fatal(err) } - desc, err := testDB.Table(name).Describe().Run() + desc, err := testDB.Table(name).Describe().Run(ctx) if err != nil { t.Fatal(err) } @@ -114,12 +117,12 @@ func TestTableLifecycle(t *testing.T) { // make sure it really works table := testDB.Table(name) - if err := table.Put(UserAction{UserID: "test", Time: now, Seq: 1, UUID: "42"}).Run(); err != nil { + if err := table.Put(UserAction{UserID: "test", Time: now, Seq: 1, UUID: "42"}).Run(ctx); err != nil { t.Fatal(err) } // delete & wait - if err := testDB.Table(name).DeleteTable().Wait(); err != nil { + if err := testDB.Table(name).DeleteTable().Wait(ctx); err != nil { t.Fatal(err) } } diff --git a/ttl.go b/ttl.go index 2e23795..10f9e65 100644 --- a/ttl.go +++ b/ttl.go @@ -32,14 +32,7 @@ func (table Table) UpdateTTL(attribute string, enabled bool) *UpdateTTL { } // Run executes this request. -func (ttl *UpdateTTL) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return ttl.RunWithContext(ctx) -} - -// RunWithContext executes this request. -func (ttl *UpdateTTL) RunWithContext(ctx context.Context) error { +func (ttl *UpdateTTL) Run(ctx context.Context) error { input := ttl.input() err := ttl.table.db.retry(ctx, func() error { @@ -70,14 +63,7 @@ func (table Table) DescribeTTL() *DescribeTTL { } // Run executes this request and returns details about time to live, or an error. -func (d *DescribeTTL) Run() (TTLDescription, error) { - ctx, cancel := defaultContext() - defer cancel() - return d.RunWithContext(ctx) -} - -// RunWithContext executes this request and returns details about time to live, or an error. -func (d *DescribeTTL) RunWithContext(ctx context.Context) (TTLDescription, error) { +func (d *DescribeTTL) Run(ctx context.Context) (TTLDescription, error) { input := d.input() var result *dynamodb.DescribeTimeToLiveOutput diff --git a/ttl_test.go b/ttl_test.go index 9ded4e3..4076647 100644 --- a/ttl_test.go +++ b/ttl_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -9,8 +10,9 @@ func TestDescribeTTL(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() - desc, err := table.DescribeTTL().Run() + desc, err := table.DescribeTTL().Run(ctx) if err != nil { t.Error(err) return diff --git a/tx.go b/tx.go index c7f9027..1caecd5 100644 --- a/tx.go +++ b/tx.go @@ -61,14 +61,7 @@ func (tx *GetTx) ConsumedCapacity(cc *ConsumedCapacity) *GetTx { } // Run executes this transaction and unmarshals everything specified by GetOne. -func (tx *GetTx) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return tx.RunWithContext(ctx) -} - -// RunWithContext executes this transaction and unmarshals everything specified by GetOne. -func (tx *GetTx) RunWithContext(ctx context.Context) error { +func (tx *GetTx) Run(ctx context.Context) error { input, err := tx.input() if err != nil { return err @@ -108,14 +101,7 @@ func (tx *GetTx) unmarshal(resp *dynamodb.TransactGetItemsOutput) error { } // All executes this transaction and unmarshals every value to out, which must be a pointer to a slice. -func (tx *GetTx) All(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return tx.AllWithContext(ctx, out) -} - -// AllWithContext executes this transaction and unmarshals every value to out, which must be a pointer to a slice. -func (tx *GetTx) AllWithContext(ctx context.Context, out interface{}) error { +func (tx *GetTx) All(ctx context.Context, out interface{}) error { input, err := tx.input() if err != nil { return err @@ -260,14 +246,7 @@ func (tx *WriteTx) ConsumedCapacity(cc *ConsumedCapacity) *WriteTx { } // Run executes this transaction. -func (tx *WriteTx) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return tx.RunWithContext(ctx) -} - -// RunWithContext executes this transaction. -func (tx *WriteTx) RunWithContext(ctx context.Context) error { +func (tx *WriteTx) Run(ctx context.Context) error { if tx.err != nil { return tx.err } diff --git a/tx_test.go b/tx_test.go index 64a6743..eb09864 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "errors" "reflect" "sync" @@ -15,6 +16,8 @@ func TestTx(t *testing.T) { t.Skip(offlineSkipMsg) } + ctx := context.TODO() + date1 := time.Date(1969, 1, 1, 1, 1, 1, 0, time.UTC) date2 := time.Date(1969, 2, 2, 2, 2, 2, 0, time.UTC) date3 := time.Date(1969, 3, 3, 3, 3, 3, 0, time.UTC) @@ -30,7 +33,7 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget2)) tx.Check(table.Check("UserID", 69).Range("Time", date3).IfNotExists()) tx.ConsumedCapacity(&cc) - err := tx.Run() + err := tx.Run(ctx) if err != nil { t.Error(err) } @@ -39,7 +42,7 @@ func TestTx(t *testing.T) { } ccold = cc - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -64,7 +67,7 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget1)) tx.Put(table.Put(widget2)) tx.ConsumedCapacity(&cc) - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -73,7 +76,7 @@ func TestTx(t *testing.T) { } ccold = cc - err = tx.Run() + err = tx.Run(ctx) if err != nil { t.Error(err) } @@ -95,7 +98,7 @@ func TestTx(t *testing.T) { getTx.GetOne(table.Get("UserID", 69).Range("Time", Equal, date2), &record2) getTx.GetOne(table.Get("UserID", 69).Range("Time", Equal, date3), &record3) getTx.ConsumedCapacity(&cc2) - err = getTx.Run() + err = getTx.Run(ctx) if err != nil { t.Error(err) } @@ -115,7 +118,7 @@ func TestTx(t *testing.T) { // All oldCC2 := cc2 var records []widget - err = getTx.All(&records) + err = getTx.All(ctx, &records) if err != nil { t.Error(err) } @@ -131,7 +134,7 @@ func TestTx(t *testing.T) { tx = testDB.WriteTx() tx.Check(table.Check("UserID", widget1.UserID).Range("Time", widget1.Time).If("Msg = ?", widget1.Msg)) tx.Update(table.Update("UserID", widget2.UserID).Range("Time", widget2.Time).Set("Msg", widget2.Msg)) - if err = tx.Run(); err != nil { + if err = tx.Run(ctx); err != nil { t.Error(err) } @@ -139,12 +142,12 @@ func TestTx(t *testing.T) { tx = testDB.WriteTx() tx.Delete(table.Delete("UserID", widget1.UserID).Range("Time", widget1.Time).If("Msg = ?", widget1.Msg)) tx.Delete(table.Delete("UserID", widget2.UserID).Range("Time", widget2.Time).If("Msg = ?", widget2.Msg)) - if err = tx.Run(); err != nil { + if err = tx.Run(ctx); err != nil { t.Error(err) } // zero results - if err = getTx.Run(); err != ErrNotFound { + if err = getTx.Run(ctx); err != ErrNotFound { t.Error("expected ErrNotFound, got:", err) } @@ -153,7 +156,7 @@ func TestTx(t *testing.T) { tx.Put(table.Put(widget{UserID: 69, Time: date1}).If("'Msg' = ?", "should not exist")) tx.Put(table.Put(widget{UserID: 69, Time: date2})) tx.Check(table.Check("UserID", 69).Range("Time", date3).IfExists().If("Msg = ?", "don't exist foo")) - err = tx.Run() + err = tx.Run(ctx) if err == nil { t.Error("expected error") } else { @@ -167,12 +170,12 @@ func TestTx(t *testing.T) { t.Logf("All: %+v (len: %d)", records, len(records)) // no input - err = testDB.GetTx().All(nil) + err = testDB.GetTx().All(ctx, nil) if err != ErrNoInput { t.Error("unexpected error", err) } - err = testDB.WriteTx().Run() + err = testDB.WriteTx().Run(ctx) if err != ErrNoInput { t.Error("unexpected error", err) } @@ -182,12 +185,13 @@ func TestTxRetry(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } + ctx := context.TODO() date1 := time.Date(1999, 1, 1, 1, 1, 1, 0, time.UTC) widget1 := widget{UserID: 69, Time: date1, Msg: "dog", Count: 0} table := testDB.Table(testTable) - if err := table.Put(widget1).Run(); err != nil { + if err := table.Put(widget1).Run(ctx); err != nil { t.Fatal(err) } @@ -204,7 +208,7 @@ func TestTxRetry(t *testing.T) { tx.Update(table.Update("UserID", widget1.UserID). Range("Time", widget1.Time). Add("Count", 1)) - if err := tx.Run(); err != nil { + if err := tx.Run(ctx); err != nil { // spew.Dump(err) panic(err) } @@ -219,7 +223,7 @@ func TestTxRetry(t *testing.T) { tx.Update(table.Update("UserID", widget1.UserID). Range("Time", widget1.Time).Add("Count", 1). If("'Count' = ?", -1)) - if err := tx.Run(); err != nil && !IsCondCheckFailed(err) { + if err := tx.Run(ctx); err != nil && !IsCondCheckFailed(err) { panic(err) } }() @@ -230,13 +234,13 @@ func TestTxRetry(t *testing.T) { defer wg.Done() tx := testDB.WriteTx() tx.Update(table.Update("UserID", "\u0002").Set("Foo", "")) - _ = tx.Run() + _ = tx.Run(ctx) }() wg.Wait() var got widget - if err := table.Get("UserID", widget1.UserID).Range("Time", Equal, widget1.Time).One(&got); err != nil { + if err := table.Get("UserID", widget1.UserID).Range("Time", Equal, widget1.Time).One(ctx, &got); err != nil { t.Fatal(err) } diff --git a/update.go b/update.go index 96a67a1..f841769 100644 --- a/update.go +++ b/update.go @@ -287,14 +287,7 @@ func (u *Update) ConsumedCapacity(cc *ConsumedCapacity) *Update { } // Run executes this update. -func (u *Update) Run() error { - ctx, cancel := defaultContext() - defer cancel() - return u.RunWithContext(ctx) -} - -// RunWithContext executes this update. -func (u *Update) RunWithContext(ctx context.Context) error { +func (u *Update) Run(ctx context.Context) error { u.returnType = "NONE" _, err := u.run(ctx) return err @@ -302,15 +295,7 @@ func (u *Update) RunWithContext(ctx context.Context) error { // Value executes this update, encoding out with the new value after the update. // This is equivalent to ReturnValues = ALL_NEW in the DynamoDB API. -func (u *Update) Value(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.ValueWithContext(ctx, out) -} - -// ValueWithContext executes this update, encoding out with the new value after the update. -// This is equivalent to ReturnValues = ALL_NEW in the DynamoDB API. -func (u *Update) ValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) Value(ctx context.Context, out interface{}) error { u.returnType = "ALL_NEW" output, err := u.run(ctx) if err != nil { @@ -321,15 +306,7 @@ func (u *Update) ValueWithContext(ctx context.Context, out interface{}) error { // OldValue executes this update, encoding out with the old value before the update. // This is equivalent to ReturnValues = ALL_OLD in the DynamoDB API. -func (u *Update) OldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OldValueWithContext(ctx, out) -} - -// OldValueWithContext executes this update, encoding out with the old value before the update. -// This is equivalent to ReturnValues = ALL_OLD in the DynamoDB API. -func (u *Update) OldValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OldValue(ctx context.Context, out interface{}) error { u.returnType = "ALL_OLD" output, err := u.run(ctx) if err != nil { @@ -340,15 +317,7 @@ func (u *Update) OldValueWithContext(ctx context.Context, out interface{}) error // OnlyUpdatedValue executes this update, encoding out with only with new values of the attributes that were changed. // This is equivalent to ReturnValues = UPDATED_NEW in the DynamoDB API. -func (u *Update) OnlyUpdatedValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OnlyUpdatedValueWithContext(ctx, out) -} - -// OnlyUpdatedValueWithContext executes this update, encoding out with only with new values of the attributes that were changed. -// This is equivalent to ReturnValues = UPDATED_NEW in the DynamoDB API. -func (u *Update) OnlyUpdatedValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OnlyUpdatedValue(ctx context.Context, out interface{}) error { u.returnType = "UPDATED_NEW" output, err := u.run(ctx) if err != nil { @@ -359,15 +328,7 @@ func (u *Update) OnlyUpdatedValueWithContext(ctx context.Context, out interface{ // OnlyUpdatedOldValue executes this update, encoding out with only with old values of the attributes that were changed. // This is equivalent to ReturnValues = UPDATED_OLD in the DynamoDB API. -func (u *Update) OnlyUpdatedOldValue(out interface{}) error { - ctx, cancel := defaultContext() - defer cancel() - return u.OnlyUpdatedOldValueWithContext(ctx, out) -} - -// OnlyUpdatedOldValueWithContext executes this update, encoding out with only with old values of the attributes that were changed. -// This is equivalent to ReturnValues = UPDATED_OLD in the DynamoDB API. -func (u *Update) OnlyUpdatedOldValueWithContext(ctx context.Context, out interface{}) error { +func (u *Update) OnlyUpdatedOldValue(ctx context.Context, out interface{}) error { u.returnType = "UPDATED_OLD" output, err := u.run(ctx) if err != nil { diff --git a/update_test.go b/update_test.go index 54fd962..c3a29e1 100644 --- a/update_test.go +++ b/update_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "reflect" "testing" "time" @@ -14,6 +15,7 @@ func TestUpdate(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() type widget2 struct { widget @@ -40,7 +42,7 @@ func TestUpdate(t *testing.T) { MySet2: map[string]struct{}{"a": {}, "b": {}, "bad1": {}, "c": {}, "bad2": {}}, MySet3: map[int64]struct{}{1: {}, 999: {}, 2: {}, 3: {}, 555: {}}, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) } @@ -90,7 +92,7 @@ func TestUpdate(t *testing.T) { DeleteFromSet("MySet2", []string{"bad1", "bad2"}). DeleteFromSet("MySet3", map[int64]struct{}{999: {}, 555: {}}). ConsumedCapacity(&cc). - Value(&result) + Value(ctx, &result) expected := widget2{ widget: widget{ @@ -130,7 +132,7 @@ func TestUpdate(t *testing.T) { Range("Time", item.Time). Set("Msg", expected2.Msg). Add("Count", 1). - OnlyUpdatedValue(&updated) + OnlyUpdatedValue(ctx, &updated) if err != nil { t.Error("unexpected error:", err) } @@ -143,7 +145,7 @@ func TestUpdate(t *testing.T) { Range("Time", item.Time). Set("Msg", "this shouldn't be seen"). Add("Count", 100). - OnlyUpdatedOldValue(&updatedOld) + OnlyUpdatedOldValue(ctx, &updatedOld) if err != nil { t.Error("unexpected error:", err) } @@ -158,7 +160,7 @@ func TestUpdate(t *testing.T) { Add("Count", 1). If("'Count' > ?", 100). If("(MeaningOfLife = ?)", 42). - Value(&result) + Value(ctx, &result) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } @@ -169,6 +171,7 @@ func TestUpdateNil(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() // first, add an item to make sure there is at least one item := widget{ @@ -180,7 +183,7 @@ func TestUpdateNil(t *testing.T) { }, Count: 100, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) t.FailNow() @@ -199,7 +202,7 @@ func TestUpdateNil(t *testing.T) { Set("Meta.'ok'", (*ptrTextMarshaler)(nil)). SetExpr("'Count' = ?", (*textMarshaler)(nil)). SetExpr("MsgPtr = ?", ""). - Value(&result) + Value(ctx, &result) if err != nil { t.Error("unexpected error:", err) } @@ -224,6 +227,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() type widget2 struct { widget @@ -241,7 +245,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { CStr: customString("delete me"), SPtr: &str, } - err := table.Put(item).Run() + err := table.Put(item).Run(ctx) if err != nil { t.Error("unexpected error:", err) t.FailNow() @@ -252,7 +256,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { err = table.Update("UserID", item.UserID).Range("Time", item.Time). Set("CStr", customString("")). Set("SPtr", nil). - Value(&result) + Value(ctx, &result) if err != nil { t.Error("unexpected error:", err) } diff --git a/updatetable.go b/updatetable.go index bade515..1b6eae3 100644 --- a/updatetable.go +++ b/updatetable.go @@ -107,13 +107,7 @@ func (ut *UpdateTable) DisableStream() *UpdateTable { } // Run executes this request and describes the table. -func (ut *UpdateTable) Run() (Description, error) { - ctx, cancel := defaultContext() - defer cancel() - return ut.RunWithContext(ctx) -} - -func (ut *UpdateTable) RunWithContext(ctx context.Context) (Description, error) { +func (ut *UpdateTable) Run(ctx context.Context) (Description, error) { if ut.err != nil { return Description{}, ut.err } diff --git a/updatetable_test.go b/updatetable_test.go index 04c2c1b..5b1e863 100644 --- a/updatetable_test.go +++ b/updatetable_test.go @@ -1,6 +1,7 @@ package dynamo import ( + "context" "testing" ) @@ -10,6 +11,7 @@ func _TestUpdateTable(t *testing.T) { t.Skip(offlineSkipMsg) } table := testDB.Table(testTable) + ctx := context.TODO() desc, err := table.UpdateTable().CreateIndex(Index{ Name: "test123", @@ -23,7 +25,7 @@ func _TestUpdateTable(t *testing.T) { Read: 1, Write: 1, }, - }).Run() + }).Run(ctx) // desc, err := table.UpdateTable().DeleteIndex("test123").Run()