diff --git a/encoding.go b/encoding.go index ec27a3b..f619e4a 100644 --- a/encoding.go +++ b/encoding.go @@ -16,7 +16,6 @@ var typeCache sync.Map // unmarshalKey → *typedef type typedef struct { decoders map[unmarshalKey]decodeFunc fields []structField - root reflect.Type info *structInfo } @@ -24,7 +23,6 @@ func newTypedef(rt reflect.Type) (*typedef, error) { def := &typedef{ decoders: make(map[unmarshalKey]decodeFunc), // encoders: make(map[encodeKey]encodeFunc), - root: rt, } err := def.init(rt) return def, err diff --git a/query.go b/query.go index 498e65e..19f7b4e 100644 --- a/query.go +++ b/query.go @@ -185,10 +185,11 @@ func (q *Query) SearchLimit(limit int64) *Query { } // RequestLimit specifies the maximum amount of requests to make against DynamoDB's API. -// func (q *Query) RequestLimit(limit int) *Query { -// q.reqLimit = limit -// return q -// } +// A limit of zero or less means unlimited requests. +func (q *Query) RequestLimit(limit int) *Query { + q.reqLimit = limit + return q +} // Order specifies the desired result order. // Requires a range key (a.k.a. partition key) to be specified. @@ -286,22 +287,29 @@ func (q *Query) CountWithContext(ctx context.Context) (int64, error) { return 0, q.err } - var count int64 + var count, scanned int64 + var reqs int var res *dynamodb.QueryOutput for { - req := q.queryInput() - req.Select = selectCount + input := q.queryInput() + input.Select = selectCount err := q.table.db.retry(ctx, func() error { var err error - res, err = q.table.db.client.QueryWithContext(ctx, req) + res, err = q.table.db.client.QueryWithContext(ctx, input) if err != nil { return err } + reqs++ + if res.Count == nil { - return errors.New("nil count") + return errors.New("malformed DynamoDB response: count is nil") } count += *res.Count + if res.ScannedCount != nil { + scanned += *res.ScannedCount + } + return nil }) if err != nil { @@ -312,7 +320,10 @@ func (q *Query) CountWithContext(ctx context.Context) (int64, error) { } q.startKey = res.LastEvaluatedKey - if res.LastEvaluatedKey == nil || q.searchLimit > 0 { + if res.LastEvaluatedKey == nil || + (q.limit > 0 && count >= q.limit) || + (q.searchLimit > 0 && scanned >= q.searchLimit) || + (q.reqLimit > 0 && reqs >= q.reqLimit) { break } } diff --git a/reflect.go b/reflect.go index 1c70bd1..cd59925 100644 --- a/reflect.go +++ b/reflect.go @@ -190,11 +190,11 @@ type encodeKey struct { type structInfo struct { root reflect.Type + parent *structInfo fields map[string]*structField // by name refs map[encodeKey][]*structField types map[encodeKey]encodeFunc zeros map[reflect.Type]func(reflect.Value) bool - parent *structInfo seen map[encodeKey]struct{} queue []encodeKey diff --git a/scan.go b/scan.go index 44e3e40..2facf32 100644 --- a/scan.go +++ b/scan.go @@ -2,6 +2,7 @@ package dynamo import ( "context" + "errors" "strings" "sync" @@ -131,10 +132,11 @@ func (s *Scan) SearchLimit(limit int64) *Scan { } // RequestLimit specifies the maximum amount of requests to make against DynamoDB's API. -// func (s *Scan) RequestLimit(limit int) *Scan { -// s.reqLimit = limit -// return s -// } +// A limit of zero or less means unlimited requests. +func (s *Scan) RequestLimit(limit int) *Scan { + s.reqLimit = limit + return s +} // ConsumedCapacity will measure the throughput capacity consumed by this operation and add it to cc. func (s *Scan) ConsumedCapacity(cc *ConsumedCapacity) *Scan { @@ -260,6 +262,7 @@ func (s *Scan) CountWithContext(ctx context.Context) (int64, error) { var count, scanned int64 input := s.scanInput() input.Select = aws.String(dynamodb.SelectCount) + var reqs int for { var out *dynamodb.ScanOutput err := s.table.db.retry(ctx, func() error { @@ -268,23 +271,26 @@ func (s *Scan) CountWithContext(ctx context.Context) (int64, error) { return err }) if err != nil { - return count, err + return 0, err } + reqs++ + if out.Count == nil { + return count, errors.New("malformed DynamoDB outponse: count is nil") + } count += *out.Count - scanned += *out.ScannedCount + if out.ScannedCount != nil { + scanned += *out.ScannedCount + } if s.cc != nil { addConsumedCapacity(s.cc, out.ConsumedCapacity) } - if s.limit > 0 && count >= s.limit { - break - } - if s.searchLimit > 0 && scanned >= s.searchLimit { - break - } - if out.LastEvaluatedKey == nil { + if out.LastEvaluatedKey == nil || + (s.limit > 0 && count >= s.limit) || + (s.searchLimit > 0 && scanned >= s.searchLimit) || + (s.reqLimit > 0 && reqs >= s.reqLimit) { break }