diff --git a/README.md b/README.md index ebf7691..5a30d80 100644 --- a/README.md +++ b/README.md @@ -119,8 +119,14 @@ err := db.Table("Books").Get("ID", 555).One(dynamo.AWSEncoding(&someBook)) By default, tests are run in offline mode. Create a table called `TestDB`, with a Number Parition Key called `UserID` and a String Sort Key called `Time`. Change the table name with the environment variable `DYNAMO_TEST_TABLE`. You must specify `DYNAMO_TEST_REGION`, setting it to the AWS region where your test table is. - ```bash +```bash DYNAMO_TEST_REGION=us-west-2 go test github.com/guregu/dynamo/... -cover +``` + +Or simply run the following command to test it locally : + + ```bash +./run_tests.sh ``` ### License diff --git a/batch_test.go b/batch_test.go index 44d1bd4..1aa4460 100644 --- a/batch_test.go +++ b/batch_test.go @@ -7,7 +7,7 @@ import ( const batchSize = 101 -func TestBatchGetWrite(t *testing.T) { +func testBatchGetWrite(t *testing.T, isSequential bool) { if testDB == nil { t.Skip(offlineSkipMsg) } @@ -29,7 +29,13 @@ func TestBatchGetWrite(t *testing.T) { } var wcc ConsumedCapacity - wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run() + var wrote int + var err error + if isSequential { + wrote, err = table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run() + } else { + wrote, err = table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).RunConcurrently() + } if wrote != batchSize { t.Error("unexpected wrote:", wrote, "≠", batchSize) } @@ -90,3 +96,11 @@ func TestBatchGetWrite(t *testing.T) { t.Error("expected 0 results, got", len(results)) } } + +func TestSequentialBatchGetWrite(t *testing.T) { + testBatchGetWrite(t, true) +} + +func TestConcurrentBatchGetWrite(t *testing.T) { + testBatchGetWrite(t, false) +} diff --git a/batchwrite.go b/batchwrite.go index 76514c8..53824e9 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -1,11 +1,13 @@ package dynamo import ( + "errors" "math" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/cenkalti/backoff" + multierror "github.com/hashicorp/go-multierror" ) // DynamoDB API limit, 25 operations per request @@ -61,6 +63,180 @@ func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite { return bw } +// Structure passed to the concurrent batch write operation +type batchRequest struct { + ctx aws.Context + ops []*dynamodb.WriteRequest +} + +// Structure returned after a concurrent batch operation +type batchResponse struct { + Result *dynamodb.BatchWriteItemOutput + Error error + Wrote int +} + +// Config used when calling RunConcurrently +type batchWriteConfig struct { + poolSize int +} + +// Parameter type to be passed to RunConcurrently +type BatchWriteOption func(*batchWriteConfig) + +// Sets the default config +func defaults(cfg *batchWriteConfig) { + cfg.poolSize = 10 +} + +// Sets the pool size to process the request +func WithPoolSize(poolSize int) BatchWriteOption { + return func(cfg *batchWriteConfig) { + cfg.poolSize = poolSize + } +} + +func (bw *BatchWrite) writeBatch(ctx aws.Context, ops []*dynamodb.WriteRequest) batchResponse { + + boff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) + wrote := 0 + + for { + var res *dynamodb.BatchWriteItemOutput + req := bw.input(ops) + err := retry(ctx, func() error { + var err error + res, err = bw.batch.table.db.client.BatchWriteItemWithContext(ctx, req) + return err + }) + if err != nil { + return batchResponse{ + Result: res, + Error: err, + Wrote: 0, + } + } + if bw.cc != nil { + for _, cc := range res.ConsumedCapacity { + addConsumedCapacity(bw.cc, cc) + } + } + + unprocessed := res.UnprocessedItems[bw.batch.table.Name()] + wrote = len(ops) - len(unprocessed) + if len(unprocessed) == 0 { + return batchResponse{ + Result: res, + Error: err, + Wrote: wrote, + } + } + ops = unprocessed + + // need to sleep when re-requesting, per spec + if err := aws.SleepWithContext(ctx, boff.NextBackOff()); err != nil { + return batchResponse{ + Result: nil, + Error: err, + Wrote: wrote, + } + } + } +} + +func (bw *BatchWrite) writeBatchWorker(worker int, requests <-chan batchRequest, response chan<- batchResponse) { + for request := range requests { + response <- bw.writeBatch(request.ctx, request.ops) + } +} + +func splitBatches(requests []*dynamodb.WriteRequest) (batches [][]*dynamodb.WriteRequest) { + batches = [][]*dynamodb.WriteRequest{} + requestsLength := len(requests) + for i := 0; i < requestsLength; i += maxWriteOps { + end := i + maxWriteOps + if end > requestsLength { + end = requestsLength + } + batches = append(batches, requests[i:end]) + } + return batches +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +// RunConcurrently executes this batch concurrently with the pool size specified. +// By default, the pool size is 10 +func (bw *BatchWrite) RunConcurrently(opts ...BatchWriteOption) (wrote int, err error) { + ctx, cancel := defaultContext() + defer cancel() + return bw.RunConcurrentlyWithContext(ctx, opts...) +} + +func (bw *BatchWrite) RunConcurrentlyWithContext(ctx aws.Context, opts ...BatchWriteOption) (wrote int, err error) { + + if bw.err != nil { + return 0, bw.err + } + + cfg := new(batchWriteConfig) + defaults(cfg) + for _, fn := range opts { + fn(cfg) + } + + // TODO : Can split the batches and run them concurrently ? + batches := splitBatches(bw.ops) + totalBatches := len(batches) + + requests := make(chan batchRequest, totalBatches) + response := make(chan batchResponse, totalBatches) + defer close(response) + + // Create the workers + for i := 0; i < cfg.poolSize; i++ { + go bw.writeBatchWorker(i, requests, response) + } + + // Push the write requests + for i := 0; i < totalBatches; i++ { + requests <- batchRequest{ + ctx: ctx, + ops: batches[i], + } + } + close(requests) + + // Capture the response + wrote = 0 + batchCounter := 0 + for { + select { + case batchResponse, ok := <-response: + if !ok { + err = multierror.Append(err, errors.New("channel unexpectedly closed")) + return wrote, err + } + if batchResponse.Error != nil { + err = multierror.Append(err, batchResponse.Error) + } + wrote += batchResponse.Wrote + batchCounter++ + if batchCounter == totalBatches { + return wrote, err + } + case <-ctx.Done(): + err = multierror.Append(err, ctx.Err()) + return wrote, err + } + } +} + // Run executes this batch. // For batches with more than 25 operations, an error could indicate that // some records have been written and some have not. Consult the wrote @@ -68,6 +244,7 @@ func (bw *BatchWrite) ConsumedCapacity(cc *ConsumedCapacity) *BatchWrite { func (bw *BatchWrite) Run() (wrote int, err error) { ctx, cancel := defaultContext() defer cancel() + // TODO : Perhaps use RunConcurrentlyWithContext(dynamo.WithPoolSize(1)) instead ? return bw.RunWithContext(ctx) } diff --git a/db_test.go b/db_test.go index 210363e..9e33b22 100644 --- a/db_test.go +++ b/db_test.go @@ -18,8 +18,13 @@ var ( const offlineSkipMsg = "DYNAMO_TEST_REGION not set" func init() { - if region := os.Getenv("DYNAMO_TEST_REGION"); region != "" { - testDB = New(session.New(), &aws.Config{Region: aws.String(region)}) + region := os.Getenv("DYNAMO_TEST_REGION") + endpoint := os.Getenv("DYNAMO_ENDPOINT") + if region != "" && endpoint != "" { + testDB = New(session.New(), &aws.Config{ + Region: aws.String(region), + Endpoint: aws.String(endpoint), + }) } if table := os.Getenv("DYNAMO_TEST_TABLE"); table != "" { testTable = table diff --git a/go.mod b/go.mod index c2aff99..dd59197 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ require ( github.com/cenkalti/backoff v2.1.1+incompatible github.com/davecgh/go-spew v1.1.1 // indirect github.com/gofrs/uuid v3.2.0+incompatible + github.com/hashicorp/go-multierror v1.0.0 github.com/stretchr/testify v1.3.0 // indirect golang.org/x/net v0.0.0-20190318221613-d196dffd7c2b ) diff --git a/go.sum b/go.sum index cdd19cb..bb20b72 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 0000000..5a8113c --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +docker rm -f dynamodb > /dev/null +docker run --name dynamodb -p 8000:8000 amazon/dynamodb-local > /dev/null & + + +export DYNAMO_ENDPOINT="http://localhost:8000" +export DYNAMO_TEST_REGION="us-west-2" +export DYNAMO_TEST_TABLE="TestDB" + +aws dynamodb delete-table \ +--table-name $DYNAMO_TEST_TABLE \ +--endpoint-url $DYNAMO_ENDPOINT > /dev/null 2>&1 + +aws dynamodb create-table \ + --table-name $DYNAMO_TEST_TABLE \ + --attribute-definitions \ + AttributeName=UserID,AttributeType=N \ + AttributeName=Time,AttributeType=S \ + --key-schema \ + AttributeName=UserID,KeyType=HASH \ + AttributeName=Time,KeyType=RANGE \ + --provisioned-throughput ReadCapacityUnits=1000,WriteCapacityUnits=1000 \ + --region $DYNAMO_TEST_REGION \ + --endpoint-url $DYNAMO_ENDPOINT > /dev/null + +go test . -cover