diff --git a/batch_test.go b/batch_test.go index a508ba5..a4b4df3 100644 --- a/batch_test.go +++ b/batch_test.go @@ -137,10 +137,23 @@ func TestBatchGetEmptySets(t *testing.T) { results = []widget{} err = table.Batch("UserID", "Time").Get(keysToCheck[:len(keysToCheck)-1]...).Consistent(true).All(&results) - if err != nil { + if err != ErrNotFound { t.Error(err) } if len(results) != 0 { t.Error("batch get empty set, unexpected length:", len(results), "want:", 0) } } + +func TestBatchEmptyInput(t *testing.T) { + table := testDB.Table(testTable) + err := table.Batch("UserID", "Time").Get().All(nil) + if err != ErrNoInput { + t.Error("unexpected error", err) + } + + _, err = table.Batch("UserID", "Time").Write().Run() + if err != ErrNoInput { + t.Error("unexpected error", err) + } +} diff --git a/batchget.go b/batchget.go index c19d3a1..9877f62 100644 --- a/batchget.go +++ b/batchget.go @@ -179,6 +179,10 @@ type bgIter struct { } func newBGIter(bg *BatchGet, fn unmarshalFunc, err error) *bgIter { + if err == nil && len(bg.reqs) == 0 { + err = ErrNoInput + } + iter := &bgIter{ bg: bg, err: err, diff --git a/batchwrite.go b/batchwrite.go index 76514c8..d44ac40 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -75,6 +75,9 @@ func (bw *BatchWrite) RunWithContext(ctx aws.Context) (wrote int, err error) { if bw.err != nil { return 0, bw.err } + if len(bw.ops) == 0 { + return 0, ErrNoInput + } // TODO: this could be made to be more efficient, // by combining unprocessed items with the next request. diff --git a/tx.go b/tx.go index c146e9b..db18f4b 100644 --- a/tx.go +++ b/tx.go @@ -1,11 +1,17 @@ package dynamo import ( + "errors" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/gofrs/uuid" ) +// ErrNoInput is returned when APIs that can take multiple inputs are run with zero inputs. +// For example, in a transaction with no operations. +var ErrNoInput = errors.New("dynamo: no input items") + type getTxOp interface { getTxItem() (*dynamodb.TransactGetItem, error) } @@ -143,6 +149,9 @@ func (tx *GetTx) AllWithContext(ctx aws.Context, out interface{}) error { } func (tx *GetTx) input() (*dynamodb.TransactGetItemsInput, error) { + if len(tx.items) == 0 { + return nil, ErrNoInput + } input := &dynamodb.TransactGetItemsInput{} for _, item := range tx.items { tgi, err := item.getTxItem() @@ -269,6 +278,9 @@ func (tx *WriteTx) RunWithContext(ctx aws.Context) error { } func (tx *WriteTx) input() (*dynamodb.TransactWriteItemsInput, error) { + if len(tx.items) == 0 { + return nil, ErrNoInput + } input := &dynamodb.TransactWriteItemsInput{} for _, item := range tx.items { wti, err := item.writeTxItem() diff --git a/tx_test.go b/tx_test.go index fcef71f..a3a3873 100644 --- a/tx_test.go +++ b/tx_test.go @@ -164,4 +164,15 @@ func TestTx(t *testing.T) { t.Logf("1: %+v 2: %+v 3: %+v", record1, record2, record3) t.Logf("All: %+v (len: %d)", records, len(records)) + + // no input + err = testDB.GetTx().All(nil) + if err != ErrNoInput { + t.Error("unexpected error", err) + } + + err = testDB.WriteTx().Run() + if err != ErrNoInput { + t.Error("unexpected error", err) + } }