From 973bed64a4702d8c4b2ca953ddcfbccc8f89a374 Mon Sep 17 00:00:00 2001 From: guregu Date: Fri, 23 Aug 2024 10:55:04 +0900 Subject: [PATCH] add IncludeItemInCondCheckFail and friends --- db.go | 4 ++-- delete.go | 18 ++++++++++++++++-- put.go | 18 ++++++++++++++++-- put_test.go | 5 ++++- tx.go | 42 +++++++++++++++++++++++++++++++++++++----- tx_test.go | 1 + update.go | 18 ++++++++++++++++-- 7 files changed, 92 insertions(+), 14 deletions(-) diff --git a/db.go b/db.go index 4c726ed..1e96f05 100644 --- a/db.go +++ b/db.go @@ -213,7 +213,7 @@ func UnmarshalItemFromCondCheckFailed(condCheckErr error, out any) (match bool, var cfe *types.ConditionalCheckFailedException if errors.As(condCheckErr, &cfe) { if cfe.Item == nil { - return true, fmt.Errorf("dynamo: ConditionalCheckFailedException does not contain item") + return true, fmt.Errorf("dynamo: ConditionalCheckFailedException does not contain item (is IncludeItemInCondCheckFail disabled?): %w", condCheckErr) } return true, UnmarshalItem(cfe.Item, out) } @@ -233,7 +233,7 @@ func UnmarshalItemsFromTxCondCheckFailed(txCancelErr error, out any) (match bool for _, cr := range txe.CancellationReasons { if cr.Code != nil && *cr.Code == "ConditionalCheckFailed" { if cr.Item == nil { - return true, fmt.Errorf("dynamo: TransactionCanceledException.CancellationReasons does not contain item") + return true, fmt.Errorf("dynamo: TransactionCanceledException.CancellationReasons does not contain item (is IncludeItemInCondCheckFail disabled?): %w", txCancelErr) } if err = unmarshal(cr.Item, out); err != nil { return true, err diff --git a/delete.go b/delete.go index d7be2cf..e5dc9a5 100644 --- a/delete.go +++ b/delete.go @@ -11,8 +11,10 @@ import ( // Delete is a request to delete an item. // See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DeleteItem.html type Delete struct { - table Table + table Table + returnType types.ReturnValue + onCondFail types.ReturnValuesOnConditionCheckFailure hashKey string hashValue types.AttributeValue @@ -108,6 +110,7 @@ func (d *Delete) OldValue(ctx context.Context, out interface{}) error { // See also: [UnmarshalItemFromCondCheckFailed]. func (d *Delete) CurrentValue(ctx context.Context, out interface{}) (wrote bool, err error) { d.returnType = types.ReturnValueNone + d.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld _, err = d.run(ctx) if err != nil { if ok, err := UnmarshalItemFromCondCheckFailed(err, out); ok { @@ -118,6 +121,17 @@ func (d *Delete) CurrentValue(ctx context.Context, out interface{}) (wrote bool, return true, nil } +// IncludeAllItemsInCondCheckFail specifies whether an item delete that fails its condition check should include the item itself in the error. +// Such items can be extracted using [UnmarshalItemFromCondCheckFailed] for single deletes, or [UnmarshalItemsFromTxCondCheckFailed] for write transactions. +func (d *Delete) IncludeItemInCondCheckFail(enabled bool) *Delete { + if enabled { + d.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld + } else { + d.onCondFail = types.ReturnValuesOnConditionCheckFailureNone + } + return d +} + func (d *Delete) run(ctx context.Context) (*dynamodb.DeleteItemOutput, error) { if d.err != nil { return nil, d.err @@ -147,7 +161,7 @@ func (d *Delete) deleteInput() *dynamodb.DeleteItemInput { } if d.condition != "" { input.ConditionExpression = &d.condition - input.ReturnValuesOnConditionCheckFailure = types.ReturnValuesOnConditionCheckFailureAllOld + input.ReturnValuesOnConditionCheckFailure = d.onCondFail } if d.cc != nil { input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes diff --git a/put.go b/put.go index 4fd9cf2..5d67b3e 100644 --- a/put.go +++ b/put.go @@ -10,8 +10,10 @@ import ( // Put is a request to create or replace an item. // See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_PutItem.html type Put struct { - table Table + table Table + returnType types.ReturnValue + onCondFail types.ReturnValuesOnConditionCheckFailure item Item subber @@ -83,6 +85,7 @@ func (p *Put) OldValue(ctx context.Context, out interface{}) error { // See also: [UnmarshalItemFromCondCheckFailed]. func (p *Put) CurrentValue(ctx context.Context, out interface{}) (wrote bool, err error) { p.returnType = types.ReturnValueNone + p.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld item, _, err := p.run(ctx) wrote = err == nil if err != nil { @@ -93,6 +96,17 @@ func (p *Put) CurrentValue(ctx context.Context, out interface{}) (wrote bool, er return } +// IncludeAllItemsInCondCheckFail specifies whether an item put that fails its condition check should include the item itself in the error. +// Such items can be extracted using [UnmarshalItemFromCondCheckFailed] for single puts, or [UnmarshalItemsFromTxCondCheckFailed] for write transactions. +func (p *Put) IncludeItemInCondCheckFail(enabled bool) *Put { + if enabled { + p.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld + } else { + p.onCondFail = types.ReturnValuesOnConditionCheckFailureNone + } + return p +} + func (p *Put) run(ctx context.Context) (item Item, output *dynamodb.PutItemOutput, err error) { if p.err != nil { return nil, nil, p.err @@ -121,7 +135,7 @@ func (p *Put) input() *dynamodb.PutItemInput { } if p.condition != "" { input.ConditionExpression = &p.condition - input.ReturnValuesOnConditionCheckFailure = types.ReturnValuesOnConditionCheckFailureAllOld + input.ReturnValuesOnConditionCheckFailure = p.onCondFail } if p.cc != nil { input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes diff --git a/put_test.go b/put_test.go index 7251afa..1a33ed3 100644 --- a/put_test.go +++ b/put_test.go @@ -70,7 +70,10 @@ func TestPut(t *testing.T) { // putting the same item: this should fail t.Run("UnmarshalItemFromCondCheckFailed", func(t *testing.T) { - err := table.Put(newItem).If("attribute_not_exists(UserID)").If("attribute_not_exists('Time')").Run(ctx) + err := table.Put(newItem). + If("attribute_not_exists(UserID)"). + If("attribute_not_exists('Time')"). + IncludeItemInCondCheckFail(true).Run(ctx) if !IsCondCheckFailed(err) { t.Error("expected ConditionalCheckFailedException, not", err) } diff --git a/tx.go b/tx.go index 7679f31..45acaa3 100644 --- a/tx.go +++ b/tx.go @@ -168,11 +168,12 @@ type writeTxOp interface { // WriteTx is analogous to TransactWriteItems in DynamoDB's API. // See: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TransactWriteItems.html type WriteTx struct { - db *DB - items []writeTxOp - token string - cc *ConsumedCapacity - err error + db *DB + items []writeTxOp + token string + onCondFail types.ReturnValuesOnConditionCheckFailure + cc *ConsumedCapacity + err error } // WriteTx begins a new write transaction. @@ -206,6 +207,20 @@ func (tx *WriteTx) Check(check *ConditionCheck) *WriteTx { return tx } +// IncludeAllItemsInCondCheckFail specifies whether an item write that fails its condition check should include the item itself in the error. +// Such items can be extracted using [UnmarshalItemsFromTxCondCheckFailed]. +// +// By default, the individual settings for each item are respected. +// Calling this will override all individual settings. +func (tx *WriteTx) IncludeAllItemsInCondCheckFail(enabled bool) *WriteTx { + if enabled { + tx.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld + } else { + tx.onCondFail = types.ReturnValuesOnConditionCheckFailureNone + } + return tx +} + // Idempotent marks this transaction as idempotent when enabled is true. // This automatically generates a unique idempotency token for you. // An idempotent transaction ran multiple times will have the same effect as being run once. @@ -279,6 +294,7 @@ func (tx *WriteTx) input() (*dynamodb.TransactWriteItemsInput, error) { if err != nil { return nil, err } + setTWIReturnType(wti, tx.onCondFail) input.TransactItems = append(input.TransactItems, *wti) } if tx.token != "" { @@ -290,6 +306,22 @@ func (tx *WriteTx) input() (*dynamodb.TransactWriteItemsInput, error) { return input, nil } +func setTWIReturnType(wti *types.TransactWriteItem, ret types.ReturnValuesOnConditionCheckFailure) { + if ret == "" { + return + } + switch { + case wti.ConditionCheck != nil: + wti.ConditionCheck.ReturnValuesOnConditionCheckFailure = ret + case wti.Delete != nil: + wti.Delete.ReturnValuesOnConditionCheckFailure = ret + case wti.Put != nil: + wti.Put.ReturnValuesOnConditionCheckFailure = ret + case wti.Update != nil: + wti.Update.ReturnValuesOnConditionCheckFailure = ret + } +} + func (tx *WriteTx) setError(err error) { if tx.err == nil { tx.err = err diff --git a/tx_test.go b/tx_test.go index f73ac8a..9b89c2d 100644 --- a/tx_test.go +++ b/tx_test.go @@ -143,6 +143,7 @@ func TestTx(t *testing.T) { tx := testDB.WriteTx() tx.Put(table.Put(widget{UserID: 69, Time: date1}).If("'BadField' = ?", "should not exist")) tx.Put(table.Put(widget{UserID: 69, Time: date2}).If("'BadField' = ?", "should not exist")) + tx.IncludeAllItemsInCondCheckFail(true) err := tx.Run(ctx) if err == nil { t.Fatal("expected error") diff --git a/update.go b/update.go index 24bec84..17757a6 100644 --- a/update.go +++ b/update.go @@ -13,8 +13,10 @@ import ( // It uses the UpdateItem API. // See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_UpdateItem.html type Update struct { - table Table + table Table + returnType types.ReturnValue + onCondFail types.ReturnValuesOnConditionCheckFailure hashKey string hashValue types.AttributeValue @@ -347,6 +349,7 @@ func (u *Update) OnlyUpdatedOldValue(ctx context.Context, out interface{}) error // See also: [UnmarshalItemFromCondCheckFailed]. func (u *Update) CurrentValue(ctx context.Context, out interface{}) (wrote bool, err error) { u.returnType = types.ReturnValueAllNew + u.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld output, err := u.run(ctx) if err != nil { if ok, err := UnmarshalItemFromCondCheckFailed(err, out); ok { @@ -357,6 +360,17 @@ func (u *Update) CurrentValue(ctx context.Context, out interface{}) (wrote bool, return true, unmarshalItem(output.Attributes, out) } +// IncludeAllItemsInCondCheckFail specifies whether an item update that fails its condition check should include the item itself in the error. +// Such items can be extracted using [UnmarshalItemFromCondCheckFailed] for single updates, or [UnmarshalItemsFromTxCondCheckFailed] for write transactions. +func (u *Update) IncludeItemInCondCheckFail(enabled bool) *Update { + if enabled { + u.onCondFail = types.ReturnValuesOnConditionCheckFailureAllOld + } else { + u.onCondFail = types.ReturnValuesOnConditionCheckFailureNone + } + return u +} + func (u *Update) run(ctx context.Context) (*dynamodb.UpdateItemOutput, error) { if u.err != nil { return nil, u.err @@ -387,7 +401,7 @@ func (u *Update) updateInput() *dynamodb.UpdateItemInput { } if u.condition != "" { input.ConditionExpression = &u.condition - input.ReturnValuesOnConditionCheckFailure = types.ReturnValuesOnConditionCheckFailureAllOld + input.ReturnValuesOnConditionCheckFailure = u.onCondFail } if u.cc != nil { input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes