Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi committed Nov 7, 2023
1 parent bbaa41c commit 6f90805
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 57 deletions.
69 changes: 69 additions & 0 deletions backend/internal/nucleusdb/mock_Tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package nucleusdb

import (
"context"

"github.com/jackc/pgx/v5"
pgconn "github.com/jackc/pgx/v5/pgconn"
mock "github.com/stretchr/testify/mock"
)

// MockTx is a mock type for the Tx interface
type MockTx struct {
mock.Mock
}

func (m *MockTx) Begin(ctx context.Context) (pgx.Tx, error) {
args := m.Called(ctx)
return args.Get(0).(pgx.Tx), args.Error(1)
}

func (m *MockTx) Commit(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}

func (m *MockTx) Rollback(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}

func (m *MockTx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
args := m.Called(ctx, tableName, columnNames, rowSrc)
return args.Get(0).(int64), args.Error(1)
}

func (m *MockTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
args := m.Called(ctx, b)
return args.Get(0).(pgx.BatchResults)
}

func (m *MockTx) LargeObjects() pgx.LargeObjects {
args := m.Called()
return args.Get(0).(pgx.LargeObjects)
}

func (m *MockTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
args := m.Called(ctx, name, sql)
return args.Get(0).(*pgconn.StatementDescription), args.Error(1)
}

func (m *MockTx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
args := m.Called(ctx, sql, arguments)
return args.Get(0).(pgconn.CommandTag), args.Error(1)
}

func (m *MockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
callArgs := m.Called(ctx, sql, args)
return callArgs.Get(0).(pgx.Rows), callArgs.Error(1)
}

func (m *MockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
callArgs := m.Called(ctx, sql, args)
return callArgs.Get(0).(pgx.Row)
}

func (m *MockTx) Conn() *pgx.Conn {
args := m.Called()
return args.Get(0).(*pgx.Conn)
}
2 changes: 1 addition & 1 deletion backend/internal/nucleusdb/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (d *NucleusDb) CreateTeamAccount(
) (*db_queries.NeosyncApiAccount, error) {
var teamAccount *db_queries.NeosyncApiAccount
if err := d.WithTx(ctx, nil, func(dbtx BaseDBTX) error {
accounts, err := d.Q.GetTeamAccountsByUserId(ctx, dbtx, userId)
accounts, err := d.Q.GetAccountsByUser(ctx, dbtx, userId)
if err != nil && !IsNoRows(err) {
return err

Check warning on line 119 in backend/internal/nucleusdb/users.go

View check run for this annotation

Codecov / codecov/patch

backend/internal/nucleusdb/users.go#L119

Added line #L119 was not covered by tests
} else if err != nil && IsNoRows(err) {
Expand Down
142 changes: 86 additions & 56 deletions backend/internal/nucleusdb/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,101 +2,131 @@ package nucleusdb

import (
"context"
"database/sql"
"errors"
"testing"

"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5"
db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db"
mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/mock"
"github.com/zeebo/assert"
)

const (
anonymousUserId = "00000000-0000-0000-0000-000000000000"
mockUserId = "d5e29f1f-b920-458c-8b86-f3a180e06d98"
mockAccountId = "5629813e-1a35-4874-922c-9827d85f0378"
mockTeamName = "team-name"
)

// MockTx is a mock type for the Tx interface
type MockTx struct {
mock.Mock
}
// CreateTeamAccount
func Test_CreateTeamAccount(t *testing.T) {
dbtxMock := NewMockDBTX(t)
querierMock := db_queries.NewMockQuerier(t)
mockTx := new(MockTx)

func (m *MockTx) Begin(ctx context.Context) (pgx.Tx, error) {
args := m.Called(ctx)
return args.Get(0).(pgx.Tx), args.Error(1)
}
userUuid, _ := ToUuid(mockUserId)
accountUuid, _ := ToUuid(mockAccountId)
ctx := context.Background()

func (m *MockTx) Commit(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
service := New(dbtxMock, querierMock)

func (m *MockTx) Rollback(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
dbtxMock.On("Begin", ctx).Return(mockTx, nil)
querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil)
querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil)
querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{
AccountID: accountUuid,
UserID: userUuid,
}).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil)
mockTx.On("Commit", ctx).Return(nil)
mockTx.On("Rollback", ctx).Return(nil)

func (m *MockTx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
args := m.Called(ctx, tableName, columnNames, rowSrc)
return args.Get(0).(int64), args.Error(1)
}
resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName)

func (m *MockTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
args := m.Called(ctx, b)
return args.Get(0).(pgx.BatchResults)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, accountUuid, resp.ID)
assert.Equal(t, mockTeamName, resp.AccountSlug)
}

func (m *MockTx) LargeObjects() pgx.LargeObjects {
args := m.Called()
return args.Get(0).(pgx.LargeObjects)
}
func Test_CreateTeamAccount_AlreadyExists(t *testing.T) {
dbtxMock := NewMockDBTX(t)
querierMock := db_queries.NewMockQuerier(t)
mockTx := new(MockTx)

func (m *MockTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
args := m.Called(ctx, name, sql)
return args.Get(0).(*pgconn.StatementDescription), args.Error(1)
}
userUuid, _ := ToUuid(mockUserId)
ctx := context.Background()

func (m *MockTx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
args := m.Called(ctx, sql, arguments)
return args.Get(0).(pgconn.CommandTag), args.Error(1)
}
service := New(dbtxMock, querierMock)

func (m *MockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
callArgs := m.Called(ctx, sql, args)
return callArgs.Get(0).(pgx.Rows), callArgs.Error(1)
}
dbtxMock.On("Begin", ctx).Return(mockTx, nil)
querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: mockTeamName}}, nil)
mockTx.On("Rollback", ctx).Return(nil)

func (m *MockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
callArgs := m.Called(ctx, sql, args)
return callArgs.Get(0).(pgx.Row)
}
resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName)

func (m *MockTx) Conn() *pgx.Conn {
args := m.Called()
return args.Get(0).(*pgx.Conn)
querierMock.AssertNotCalled(t, "CreateTeamAccount", mock.Anything, mock.Anything, mock.Anything)
querierMock.AssertNotCalled(t, "CreateAccountUserAssociation", mock.Anything, mock.Anything, mock.Anything)
mockTx.AssertNotCalled(t, "Commit", mock.Anything)

assert.Error(t, err)
assert.Nil(t, resp)
}

// CreateTeamAccount
func Test_GetUser_Anonymous(t *testing.T) {
func Test_CreateTeamAccount_NoRows(t *testing.T) {
dbtxMock := NewMockDBTX(t)
querierMock := db_queries.NewMockQuerier(t)
mockTx := new(MockTx)

userUuid, _ := ToUuid(mockUserId)
accountUuid, _ := ToUuid(mockAccountId)
teamName := "team-name"
var nilAccounts []db_queries.NeosyncApiAccount
ctx := context.Background()

service := New(dbtxMock, querierMock)

dbtxMock.On("Begin", ctx).Return(mockTx, nil)
querierMock.On("GetTeamAccountsByUserId", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{}, nil)
querierMock.On("CreateTeamAccount", ctx, mockTx, teamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid}, nil)
querierMock.On("CreateAccountUserAssociation", ctx, mockTx).Return(nil, nil)
querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return(nilAccounts, sql.ErrNoRows)
querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil)
querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{
AccountID: accountUuid,
UserID: userUuid,
}).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil)
mockTx.On("Commit", ctx).Return(nil)
mockTx.On("Rollback", ctx).Return(nil)

resp, err := service.CreateTeamAccount(context.Background(), userUuid, teamName)
resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName)

assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, accountUuid, resp.ID)
assert.Equal(t, mockTeamName, resp.AccountSlug)
}

func Test_CreateTeamAccount_Rollback(t *testing.T) {
dbtxMock := NewMockDBTX(t)
querierMock := db_queries.NewMockQuerier(t)
mockTx := new(MockTx)

userUuid, _ := ToUuid(mockUserId)
accountUuid, _ := ToUuid(mockAccountId)
ctx := context.Background()
var nilAssociation db_queries.NeosyncApiAccountUserAssociation

service := New(dbtxMock, querierMock)

dbtxMock.On("Begin", ctx).Return(mockTx, nil)
querierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil)
querierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil)
querierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{
AccountID: accountUuid,
UserID: userUuid,
}).Return(nilAssociation, errors.New("sad"))
mockTx.On("Rollback", ctx).Return(nil)

resp, err := service.CreateTeamAccount(context.Background(), userUuid, mockTeamName)

mockTx.AssertCalled(t, "Rollback", ctx)
mockTx.AssertNotCalled(t, "Commit", mock.Anything)
assert.Error(t, err)
assert.Nil(t, resp)
}
28 changes: 28 additions & 0 deletions backend/services/mgmt/v1alpha1/user-account-service/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,34 @@ func Test_IsUserInAccount_False(t *testing.T) {
assert.Equal(t, false, resp.Msg.Ok)
}

func Test_CreateTeamAccount(t *testing.T) {
m := createServiceMock(t, &Config{IsAuthEnabled: true})
mockTx := new(nucleusdb.MockTx)

mockTeamName := "team-name"
ctx := getAuthenticatedCtxMock(mockAuthProvider)
userAssociation := getUserIdentityProviderAssociationMock(mockUserId, mockAuthProvider)
accountUuid, _ := nucleusdb.ToUuid(mockAccountId)
userUuid, _ := nucleusdb.ToUuid(mockUserId)
m.QuerierMock.On("GetUserAssociationByAuth0Id", ctx, mock.Anything, mockAuthProvider).Return(userAssociation, nil)
m.DbtxMock.On("Begin", ctx).Return(mockTx, nil)
m.QuerierMock.On("GetAccountsByUser", ctx, mockTx, userUuid).Return([]db_queries.NeosyncApiAccount{{AccountSlug: "other"}}, nil)
m.QuerierMock.On("CreateTeamAccount", ctx, mockTx, mockTeamName).Return(db_queries.NeosyncApiAccount{ID: accountUuid, AccountSlug: mockTeamName}, nil)
m.QuerierMock.On("CreateAccountUserAssociation", ctx, mockTx, db_queries.CreateAccountUserAssociationParams{
AccountID: accountUuid,
UserID: userUuid,
}).Return(db_queries.NeosyncApiAccountUserAssociation{}, nil)
mockTx.On("Commit", ctx).Return(nil)
mockTx.On("Rollback", ctx).Return(nil)

resp, err := m.Service.CreateTeamAccount(ctx, &connect.Request[mgmtv1alpha1.CreateTeamAccountRequest]{Msg: &mgmtv1alpha1.CreateTeamAccountRequest{Name: mockTeamName}})

assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, mockAccountId, resp.Msg.AccountId)

}

type serviceMocks struct {
Service *Service
DbtxMock *nucleusdb.MockDBTX
Expand Down

0 comments on commit 6f90805

Please sign in to comment.