From 6f908058e50bb7658c789a6b62f5b3b3320d1c41 Mon Sep 17 00:00:00 2001 From: Alisha Date: Tue, 7 Nov 2023 09:35:36 -0800 Subject: [PATCH] tests --- backend/internal/nucleusdb/mock_Tx.go | 69 +++++++++ backend/internal/nucleusdb/users.go | 2 +- backend/internal/nucleusdb/users_test.go | 142 +++++++++++------- .../user-account-service/users_test.go | 28 ++++ 4 files changed, 184 insertions(+), 57 deletions(-) create mode 100644 backend/internal/nucleusdb/mock_Tx.go diff --git a/backend/internal/nucleusdb/mock_Tx.go b/backend/internal/nucleusdb/mock_Tx.go new file mode 100644 index 0000000000..1edee167ed --- /dev/null +++ b/backend/internal/nucleusdb/mock_Tx.go @@ -0,0 +1,69 @@ +package nucleusdb + +import ( + "context" + + "github.com/jackc/pgx/v5" + pgconn "github.com/jackc/pgx/v5/pgconn" + mock "github.com/stretchr/testify/mock" +) + +// MockTx is a mock type for the Tx interface +type MockTx struct { + mock.Mock +} + +func (m *MockTx) Begin(ctx context.Context) (pgx.Tx, error) { + args := m.Called(ctx) + return args.Get(0).(pgx.Tx), args.Error(1) +} + +func (m *MockTx) Commit(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockTx) Rollback(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockTx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + args := m.Called(ctx, tableName, columnNames, rowSrc) + return args.Get(0).(int64), args.Error(1) +} + +func (m *MockTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + args := m.Called(ctx, b) + return args.Get(0).(pgx.BatchResults) +} + +func (m *MockTx) LargeObjects() pgx.LargeObjects { + args := m.Called() + return args.Get(0).(pgx.LargeObjects) +} + +func (m *MockTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + args := m.Called(ctx, name, sql) + return args.Get(0).(*pgconn.StatementDescription), args.Error(1) +} + +func (m *MockTx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + args := m.Called(ctx, sql, arguments) + return args.Get(0).(pgconn.CommandTag), args.Error(1) +} + +func (m *MockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + callArgs := m.Called(ctx, sql, args) + return callArgs.Get(0).(pgx.Rows), callArgs.Error(1) +} + +func (m *MockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + callArgs := m.Called(ctx, sql, args) + return callArgs.Get(0).(pgx.Row) +} + +func (m *MockTx) Conn() *pgx.Conn { + args := m.Called() + return args.Get(0).(*pgx.Conn) +} diff --git a/backend/internal/nucleusdb/users.go b/backend/internal/nucleusdb/users.go index be7ea1f9c9..0b5fb4ff95 100644 --- a/backend/internal/nucleusdb/users.go +++ b/backend/internal/nucleusdb/users.go @@ -114,7 +114,7 @@ func (d *NucleusDb) CreateTeamAccount( ) (*db_queries.NeosyncApiAccount, error) { var teamAccount *db_queries.NeosyncApiAccount if err := d.WithTx(ctx, nil, func(dbtx BaseDBTX) error { - accounts, err := d.Q.GetTeamAccountsByUserId(ctx, dbtx, userId) + accounts, err := d.Q.GetAccountsByUser(ctx, dbtx, userId) if err != nil && !IsNoRows(err) { return err } else if err != nil && IsNoRows(err) { diff --git a/backend/internal/nucleusdb/users_test.go b/backend/internal/nucleusdb/users_test.go index 84a85061fa..a62fa32d02 100644 --- a/backend/internal/nucleusdb/users_test.go +++ b/backend/internal/nucleusdb/users_test.go @@ -2,12 +2,12 @@ package nucleusdb import ( "context" + "database/sql" + "errors" "testing" - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v5" db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" - mock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/mock" "github.com/zeebo/assert" ) @@ -15,88 +15,118 @@ const ( anonymousUserId = "00000000-0000-0000-0000-000000000000" mockUserId = "d5e29f1f-b920-458c-8b86-f3a180e06d98" mockAccountId = "5629813e-1a35-4874-922c-9827d85f0378" + mockTeamName = "team-name" ) -// MockTx is a mock type for the Tx interface -type MockTx struct { - mock.Mock -} +// CreateTeamAccount +func Test_CreateTeamAccount(t *testing.T) { + dbtxMock := NewMockDBTX(t) + querierMock := db_queries.NewMockQuerier(t) + mockTx := new(MockTx) -func (m *MockTx) Begin(ctx context.Context) (pgx.Tx, error) { - args := m.Called(ctx) - return args.Get(0).(pgx.Tx), args.Error(1) -} + userUuid, _ := ToUuid(mockUserId) + accountUuid, _ := ToUuid(mockAccountId) + ctx := context.Background() -func (m *MockTx) Commit(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} + service := New(dbtxMock, querierMock) -func (m *MockTx) Rollback(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} + dbtxMock.On("Begin", ctx).Return(mockTx, nil) + querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil) + querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil) + querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{ + AccountID: accountUuid, + UserID: userUuid, + }).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil) + mockTx.On("Commit", ctx).Return(nil) + mockTx.On("Rollback", ctx).Return(nil) -func (m *MockTx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { - args := m.Called(ctx, tableName, columnNames, rowSrc) - return args.Get(0).(int64), args.Error(1) -} + resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName) -func (m *MockTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { - args := m.Called(ctx, b) - return args.Get(0).(pgx.BatchResults) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, accountUuid, resp.ID) + assert.Equal(t, mockTeamName, resp.AccountSlug) } -func (m *MockTx) LargeObjects() pgx.LargeObjects { - args := m.Called() - return args.Get(0).(pgx.LargeObjects) -} +func Test_CreateTeamAccount_AlreadyExists(t *testing.T) { + dbtxMock := NewMockDBTX(t) + querierMock := db_queries.NewMockQuerier(t) + mockTx := new(MockTx) -func (m *MockTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { - args := m.Called(ctx, name, sql) - return args.Get(0).(*pgconn.StatementDescription), args.Error(1) -} + userUuid, _ := ToUuid(mockUserId) + ctx := context.Background() -func (m *MockTx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - args := m.Called(ctx, sql, arguments) - return args.Get(0).(pgconn.CommandTag), args.Error(1) -} + service := New(dbtxMock, querierMock) -func (m *MockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { - callArgs := m.Called(ctx, sql, args) - return callArgs.Get(0).(pgx.Rows), callArgs.Error(1) -} + dbtxMock.On("Begin", ctx).Return(mockTx, nil) + querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: mockTeamName}}, nil) + mockTx.On("Rollback", ctx).Return(nil) -func (m *MockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { - callArgs := m.Called(ctx, sql, args) - return callArgs.Get(0).(pgx.Row) -} + resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName) -func (m *MockTx) Conn() *pgx.Conn { - args := m.Called() - return args.Get(0).(*pgx.Conn) + querierMock.AssertNotCalled(t, "CreateTeamAccount", mock.Anything, mock.Anything, mock.Anything) + querierMock.AssertNotCalled(t, "CreateAccountUserAssociation", mock.Anything, mock.Anything, mock.Anything) + mockTx.AssertNotCalled(t, "Commit", mock.Anything) + + assert.Error(t, err) + assert.Nil(t, resp) } -// CreateTeamAccount -func Test_GetUser_Anonymous(t *testing.T) { +func Test_CreateTeamAccount_NoRows(t *testing.T) { dbtxMock := NewMockDBTX(t) querierMock := db_queries.NewMockQuerier(t) mockTx := new(MockTx) userUuid, _ := ToUuid(mockUserId) accountUuid, _ := ToUuid(mockAccountId) - teamName := "team-name" + var nilAccounts []db_queries.NeosyncApiAccount ctx := context.Background() service := New(dbtxMock, querierMock) dbtxMock.On("Begin", ctx).Return(mockTx, nil) - querierMock.On("GetTeamAccountsByUserId", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{}, nil) - querierMock.On("CreateTeamAccount", ctx, mockTx, teamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid}, nil) - querierMock.On("CreateAccountUserAssociation", ctx, mockTx).Return(nil, nil) + querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return(nilAccounts, sql.ErrNoRows) + querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil) + querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{ + AccountID: accountUuid, + UserID: userUuid, + }).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil) + mockTx.On("Commit", ctx).Return(nil) + mockTx.On("Rollback", ctx).Return(nil) - resp, err := service.CreateTeamAccount(context.Background(), userUuid, teamName) + resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName) assert.NoError(t, err) assert.NotNil(t, resp) + assert.Equal(t, accountUuid, resp.ID) + assert.Equal(t, mockTeamName, resp.AccountSlug) +} + +func Test_CreateTeamAccount_Rollback(t *testing.T) { + dbtxMock := NewMockDBTX(t) + querierMock := db_queries.NewMockQuerier(t) + mockTx := new(MockTx) + + userUuid, _ := ToUuid(mockUserId) + accountUuid, _ := ToUuid(mockAccountId) + ctx := context.Background() + var nilAssociation db_queries.NeosyncApiAccountUserAssociation + + service := New(dbtxMock, querierMock) + + dbtxMock.On("Begin", ctx).Return(mockTx, nil) + querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil) + querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil) + querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{ + AccountID: accountUuid, + UserID: userUuid, + }).Return(nilAssociation, errors.New("sad")) + mockTx.On("Rollback", ctx).Return(nil) + + resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName) + + mockTx.AssertCalled(t, "Rollback", ctx) + mockTx.AssertNotCalled(t, "Commit", mock.Anything) + assert.Error(t, err) + assert.Nil(t, resp) } diff --git a/backend/services/mgmt/v1alpha1/user-account-service/users_test.go b/backend/services/mgmt/v1alpha1/user-account-service/users_test.go index 4d86ad8eb2..53987c7040 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/users_test.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/users_test.go @@ -241,6 +241,34 @@ func Test_IsUserInAccount_False(t *testing.T) { assert.Equal(t, false, resp.Msg.Ok) } +func Test_CreateTeamAccount(t *testing.T) { + m := createServiceMock(t, &Config{IsAuthEnabled: true}) + mockTx := new(nucleusdb.MockTx) + + mockTeamName := "team-name" + ctx := getAuthenticatedCtxMock(mockAuthProvider) + userAssociation := getUserIdentityProviderAssociationMock(mockUserId, mockAuthProvider) + accountUuid, _ := nucleusdb.ToUuid(mockAccountId) + userUuid, _ := nucleusdb.ToUuid(mockUserId) + m.QuerierMock.On("GetUserAssociationByAuth0Id", ctx, mock.Anything, mockAuthProvider).Return(userAssociation, nil) + m.DbtxMock.On("Begin", ctx).Return(mockTx, nil) + m.QuerierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil) + m.QuerierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil) + m.QuerierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{ + AccountID: accountUuid, + UserID: userUuid, + }).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil) + mockTx.On("Commit", ctx).Return(nil) + mockTx.On("Rollback", ctx).Return(nil) + + resp, err := m.Service.CreateTeamAccount(ctx, &connect.Request[mgmtv1alpha1.CreateTeamAccountRequest]{Msg: &mgmtv1alpha1.CreateTeamAccountRequest{Name: mockTeamName}}) + + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, mockAccountId, resp.Msg.AccountId) + +} + type serviceMocks struct { Service *Service DbtxMock *nucleusdb.MockDBTX