Skip to content

Commit

Permalink
add tests for dynamodb and fix condition when checking if table exists (
Browse files Browse the repository at this point in the history
  • Loading branch information
larhauga authored May 9, 2024
1 parent 2dd7c78 commit a9a4bd3
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 18 deletions.
42 changes: 24 additions & 18 deletions cli/pkg/aws/dynamo_locking.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,23 @@ const (
)

type DynamoDbLock struct {
DynamoDb *dynamodb.Client
DynamoDb DynamoDBClient
}

func isResourceNotFoundExceptionError(err error) bool {
if err != nil {
var apiError smithy.APIError
if errors.As(err, &apiError) {
switch apiError.(type) {
case *types.ResourceNotFoundException:
return true
}
type DynamoDBClient interface {
DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error)
CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error)
UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error)
DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error)
GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error)
}

func isTableNotFoundExceptionError(err error) bool {
var apiError smithy.APIError
if errors.As(err, &apiError) {
switch apiError.(type) {
case *types.TableNotFoundException:
return true
}
}
return false
Expand All @@ -49,7 +55,7 @@ func (dynamoDbLock *DynamoDbLock) waitUntilTableCreated(ctx context.Context) err
cnt := 0

if err != nil {
if !isResourceNotFoundExceptionError(err) {
if !isTableNotFoundExceptionError(err) {
return err
}
}
Expand All @@ -58,7 +64,7 @@ func (dynamoDbLock *DynamoDbLock) waitUntilTableCreated(ctx context.Context) err
time.Sleep(TableCreationInterval)
status, err = dynamoDbLock.DynamoDb.DescribeTable(ctx, input)
if err != nil {
if !isResourceNotFoundExceptionError(err) {
if !isTableNotFoundExceptionError(err) {
return err
}
}
Expand All @@ -78,15 +84,14 @@ func (dynamoDbLock *DynamoDbLock) createTableIfNotExists(ctx context.Context) er
_, err := dynamoDbLock.DynamoDb.DescribeTable(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(TABLE_NAME),
})

if err != nil {
if !isResourceNotFoundExceptionError(err) {
return err
}
if err == nil { // Table exists
return nil
}
if !isTableNotFoundExceptionError(err) {
return err
}

createtbl_input := &dynamodb.CreateTableInput{

AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("PK"),
Expand Down Expand Up @@ -214,7 +219,8 @@ func (dynamoDbLock *DynamoDbLock) GetLock(lockId string) (*int, error) {
}

type TransactionLock struct {
TransactionID int `dynamodbav:"transaction_id"`
TransactionID int `dynamodbav:"transaction_id"`
Timeout string `dynamodbav:"timeout"`
}

var t TransactionLock
Expand Down
91 changes: 91 additions & 0 deletions cli/pkg/aws/dynamo_locking_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package aws

import (
"context"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)

type mockDynamoDbClient struct {
table map[string]map[string]types.AttributeValue
Options dynamodb.Options
MockDescribeTable func(ctx context.Context, params dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error)
MockUpdateItem func(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error)
MockGetItem func(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error)
MockDeleteItem func(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error)
}

func (m *mockDynamoDbClient) DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) {
if m.table == nil || m.table[aws.ToString(params.TableName)] == nil {
return nil, &types.TableNotFoundException{}
}
if m.table[aws.ToString(params.TableName)] != nil {
return &dynamodb.DescribeTableOutput{Table: &types.TableDescription{TableName: params.TableName}}, nil
}
return nil, nil
}

func (m *mockDynamoDbClient) CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) {
m.table[aws.ToString(params.TableName)] = make(map[string]types.AttributeValue)
return nil, nil
}

func (m *mockDynamoDbClient) UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) {
// TODO: Implement this
return &dynamodb.UpdateItemOutput{}, nil
}

func (m *mockDynamoDbClient) GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) {
return &dynamodb.GetItemOutput{
Item: map[string]types.AttributeValue{
"PK": &types.AttributeValueMemberS{Value: "LOCK"},
"SK": &types.AttributeValueMemberS{Value: "RES#example-resource"},
"transaction_id": &types.AttributeValueMemberN{Value: "123"},
"timeout": &types.AttributeValueMemberS{Value: "2024-04-01T00:00:00Z"},
},
}, nil
}

func (m *mockDynamoDbClient) DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) {
m.table[aws.ToString(params.TableName)][aws.ToString(&params.Key["SK"].(*types.AttributeValueMemberS).Value)] = nil
return &dynamodb.DeleteItemOutput{}, nil
}

func TestDynamoDbLock_Lock(t *testing.T) {
client := mockDynamoDbClient{table: make(map[string]map[string]types.AttributeValue)}
dynamodbLock := DynamoDbLock{
DynamoDb: &client,
}
dynamodbLock.DynamoDb.CreateTable(context.Background(), &dynamodb.CreateTableInput{TableName: aws.String(TABLE_NAME)})

// Set up the input parameters for the Lock method
transactionId := 123
resource := "example-resource"

locked, err := dynamodbLock.Lock(transactionId, resource)
if err != nil {
t.Fatalf("Error: %v", err)
}
if !locked {
t.Fatalf("Expected true, got %v", locked)
}
}
func TestDynamoDbLock_GetLock(t *testing.T) {
// Create a mock DynamoDB client
client := mockDynamoDbClient{table: make(map[string]map[string]types.AttributeValue)}
dynamodbLock := DynamoDbLock{
DynamoDb: &client,
}
dynamodbLock.DynamoDb.CreateTable(context.Background(), &dynamodb.CreateTableInput{TableName: aws.String(TABLE_NAME)})

id, err := dynamodbLock.GetLock("example-resource")
if err != nil {
t.Fatalf("Error: %v", err)
}
if *id != 123 {
t.Fatalf("Expected 123, got %v", id)
}
}

0 comments on commit a9a4bd3

Please sign in to comment.