diff --git a/internal/openfga/client.go b/internal/openfga/client.go index e12d1a7bc..7d8ccc63d 100644 --- a/internal/openfga/client.go +++ b/internal/openfga/client.go @@ -22,14 +22,14 @@ import ( ) type Client struct { - c OpenFGAClientInterface + c OpenFGACoreClientInterface tracer tracing.TracingInterface monitor monitoring.MonitorInterface logger logging.LoggerInterface } -func (c *Client) APIClient() OpenFGAClientInterface { +func (c *Client) APIClient() OpenFGACoreClientInterface { return c.c } diff --git a/internal/openfga/client_test.go b/internal/openfga/client_test.go index aa74bdcb1..5133cd130 100644 --- a/internal/openfga/client_test.go +++ b/internal/openfga/client_test.go @@ -39,7 +39,7 @@ func TestNewClientAPIClientImplementsInterface(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) specs := new(EnvSpec) @@ -63,9 +63,9 @@ func TestNewClientAPIClientImplementsInterface(t *testing.T) { c.c = mockOpenFGAClient if !reflect.TypeOf(c.APIClient()).Implements( - reflect.TypeOf((*OpenFGAClientInterface)(nil)).Elem(), + reflect.TypeOf((*OpenFGACoreClientInterface)(nil)).Elem(), ) { - t.Fatal("APIClient doesn't implement interface OpenFGAClientInterface") + t.Fatal("APIClient doesn't implement interface OpenFGACoreClientInterface") } } @@ -105,7 +105,7 @@ func TestClientListObjectsSuccess(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientListObjectsRequestInterface(ctrl) c := Client{ @@ -149,7 +149,7 @@ func TestClientListObjectsFails(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientListObjectsRequestInterface(ctrl) c := Client{ @@ -230,7 +230,7 @@ func TestClientReadTuplesSuccess(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientReadRequestInterface(ctrl) c := Client{ @@ -275,7 +275,7 @@ func TestClientReadTuplesFails(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientReadRequestInterface(ctrl) c := Client{ @@ -343,7 +343,7 @@ func TestClientWriteTuplesSuccess(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientWriteRequestInterface(ctrl) c := Client{ @@ -384,7 +384,7 @@ func TestClientWriteTuplesFails(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientWriteRequestInterface(ctrl) c := Client{ @@ -440,7 +440,7 @@ func TestClientDeleteTuplesSuccess(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientWriteRequestInterface(ctrl) c := Client{ @@ -481,7 +481,7 @@ func TestClientDeleteTuplesFails(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientWriteRequestInterface(ctrl) c := Client{ @@ -565,7 +565,7 @@ func TestClientWriteBatchCheckSuccess(t *testing.T) { mockLogger := NewMockLoggerInterface(ctrl) mockTracer := NewMockTracer(ctrl) mockMonitor := monitoring.NewMockMonitorInterface(ctrl) - mockOpenFGAClient := NewMockOpenFGAClientInterface(ctrl) + mockOpenFGAClient := NewMockOpenFGACoreClientInterface(ctrl) mockRequest := NewMockSdkClientBatchCheckRequestInterface(ctrl) c := Client{ diff --git a/internal/openfga/interfaces.go b/internal/openfga/interfaces.go index 66c5eca8d..b5f7bfd5d 100644 --- a/internal/openfga/interfaces.go +++ b/internal/openfga/interfaces.go @@ -9,7 +9,7 @@ import ( "github.com/openfga/go-sdk/client" ) -type OpenFGAClientInterface interface { +type OpenFGACoreClientInterface interface { GetAuthorizationModelId() (string, error) CreateStore(context.Context) client.SdkClientCreateStoreRequestInterface CreateStoreExecute(client.SdkClientCreateStoreRequestInterface) (*client.ClientCreateStoreResponse, error) @@ -30,3 +30,12 @@ type OpenFGAClientInterface interface { ListObjects(context.Context) client.SdkClientListObjectsRequestInterface ListObjectsExecute(client.SdkClientListObjectsRequestInterface) (*client.ClientListObjectsResponse, error) } + +// OpenFGAClientInterface is the interface used to decouple the OpenFGA store implementation +type OpenFGAClientInterface interface { + ListObjects(context.Context, string, string, string) ([]string, error) + ReadTuples(context.Context, string, string, string, string) (*client.ClientReadResponse, error) + WriteTuples(context.Context, ...Tuple) error + DeleteTuples(context.Context, ...Tuple) error + Check(context.Context, string, string, string, ...Tuple) (bool, error) +} diff --git a/internal/openfga/stores.go b/internal/openfga/stores.go new file mode 100644 index 000000000..aad2c6db0 --- /dev/null +++ b/internal/openfga/stores.go @@ -0,0 +1,344 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL-3.0 + +package openfga + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/canonical/identity-platform-admin-ui/internal/logging" + "github.com/canonical/identity-platform-admin-ui/internal/monitoring" + "github.com/canonical/identity-platform-admin-ui/internal/pool" + + trace "go.opentelemetry.io/otel/trace" +) + +const ( + ASSIGNEE_RELATION = "assignee" + MEMBER_RELATION = "member" + CAN_VIEW_RELATION = "can_view" +) + +// TODO @shipperizer this is internal material, worth reusing it across the board +// OpenFGAStore is an overarching store object to deal with OpenFGA entities, meant as a low level +// object to perform cross cutting logic only relevant to the application, therefore doesn't deal with +// user interpolations or returns fancy objects, that is offloaded to the service layer favouring reusability +type OpenFGAStore struct { + ofga OpenFGAClientInterface + + wpool pool.WorkerPoolInterface + + tracer trace.Tracer + monitor monitoring.MonitorInterface + logger logging.LoggerInterface +} + +// ListViewableRoles returns all the roles a specific "assignee"able resource (user, group#member, role#assignee) is linked to (using "can_view" OpenFGA relation) +func (s *OpenFGAStore) ListViewableRoles(ctx context.Context, ID string) ([]string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.ListViewableRoles") + defer span.End() + + roles, err := s.ofga.ListObjects(ctx, ID, CAN_VIEW_RELATION, "role") + + if err != nil { + s.logger.Error(err.Error()) + return nil, err + } + + return roles, nil +} + +// ListAssignedRoles returns all the roles a specific "assignee"able resource (user, group#member, role#assignee) is linked to (using "assignee" OpenFGA relation) +func (s *OpenFGAStore) ListAssignedRoles(ctx context.Context, assigneeID string) ([]string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.ListAssignedRoles") + defer span.End() + + roles, err := s.ofga.ListObjects(ctx, assigneeID, ASSIGNEE_RELATION, "role") + + if err != nil { + s.logger.Error(err.Error()) + return nil, err + } + + return roles, nil +} + +// ListAssignedGroups returns all the groups a specific user is memeber of (using "member" OpenFGA relation) +func (s *OpenFGAStore) ListAssignedGroups(ctx context.Context, assigneeID string) ([]string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.ListAssignedGroups") + defer span.End() + + groups, err := s.ofga.ListObjects(ctx, assigneeID, MEMBER_RELATION, "group") + + if err != nil { + s.logger.Error(err.Error()) + return nil, err + } + + return groups, nil +} + +// AssignRoles assigns roles to an "assignee"able resource (user, group#member) +func (s *OpenFGAStore) AssignRoles(ctx context.Context, assigneeID string, roleIDs ...string) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.AssignRoles") + defer span.End() + + // preemptive check to verify if all roles to be assigned are accessible by the user + // needs to happen separately + + rs := make([]Tuple, 0) + + for _, roleID := range roleIDs { + rs = append(rs, *NewTuple(assigneeID, ASSIGNEE_RELATION, roleID)) + } + + err := s.ofga.WriteTuples(ctx, rs...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// UnassignRoles drops roles from an "assignee"able resource (user, group#member) +func (s *OpenFGAStore) UnassignRoles(ctx context.Context, assigneeID string, roleIDs ...string) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.UnassignRoles") + defer span.End() + + // preemptive check to verify if all roles to be assigned are accessible by the user + // needs to happen separately + + rs := make([]Tuple, 0) + + for _, roleID := range roleIDs { + rs = append(rs, *NewTuple(assigneeID, ASSIGNEE_RELATION, roleID)) + } + + err := s.ofga.DeleteTuples(ctx, rs...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// AssignGroups assigns groups to an "assignee"able resource (user, group#member) +func (s *OpenFGAStore) AssignGroups(ctx context.Context, assigneeID string, groupIDs ...string) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.AssignGroups") + defer span.End() + + // preemptive check to verify if all Groups to be assigned are accessible by the user + // needs to happen separately + + rs := make([]Tuple, 0) + + for _, groupID := range groupIDs { + rs = append(rs, *NewTuple(assigneeID, MEMBER_RELATION, groupID)) + } + + err := s.ofga.WriteTuples(ctx, rs...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// UnassignGroups drops Groups from an "assignee"able resource (user, group#member) +func (s *OpenFGAStore) UnassignGroups(ctx context.Context, assigneeID string, groupIDs ...string) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.UnassignGroups") + defer span.End() + + // preemptive check to verify if all Groups to be assigned are accessible by the user + // needs to happen separately + + rs := make([]Tuple, 0) + + for _, groupID := range groupIDs { + rs = append(rs, *NewTuple(assigneeID, MEMBER_RELATION, groupID)) + } + + err := s.ofga.DeleteTuples(ctx, rs...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// AssignPermissions assigns permissions to an "assignee"able resource (user, group#member, role#assignee) +func (s *OpenFGAStore) AssignPermissions(ctx context.Context, assigneeID string, permissions ...Permission) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.AssignPermissions") + defer span.End() + + // preemptive check to verify if all permissions to be assigned are accessible by the user + // needs to happen separately + + ps := make([]Tuple, 0) + + for _, p := range permissions { + ps = append(ps, *NewTuple(assigneeID, p.Relation, p.Object)) + } + + err := s.ofga.WriteTuples(ctx, ps...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// UnassignPermissions removes permissions from an "assignee"able resource (user, group#member, role#assignee) +func (s *OpenFGAStore) UnassignPermissions(ctx context.Context, assigneeID string, permissions ...Permission) error { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.UnassignPermissions") + defer span.End() + + // preemptive check to verify if all permissions to be assigned are accessible by the user + // needs to happen separately + + ps := make([]Tuple, 0) + + for _, p := range permissions { + ps = append(ps, *NewTuple(assigneeID, p.Relation, p.Object)) + } + + err := s.ofga.DeleteTuples(ctx, ps...) + + if err != nil { + s.logger.Error(err.Error()) + return err + } + + return nil +} + +// ListPermissions returns all the permissions associated to a specific entity +func (s *OpenFGAStore) ListPermissions(ctx context.Context, ID string, continuationTokens map[string]string) ([]Permission, map[string]string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.ListPermissions") + defer span.End() + + // keep it a buffered channel, if set to unbuffered we would need a goroutine + // to consume from it before pushing to it + // https://go.dev/ref/spec#Send_statements + // A send on an unbuffered channel can proceed if a receiver is ready. + // A send on a buffered channel can proceed if there is room in the buffer + results := make(chan *pool.Result[any], len(s.permissionTypes())) + + wg := sync.WaitGroup{} + wg.Add(len(s.permissionTypes())) + + for _, t := range s.permissionTypes() { + s.wpool.Submit( + s.listPermissionsFunc(ctx, ID, t, continuationTokens[t]), + results, + &wg, + ) + } + + // wait for tasks to finish + wg.Wait() + + // close result channel + close(results) + + permissions := make([]Permission, 0) + tMap := make(map[string]string) + errors := make([]error, 0) + + for r := range results { + v := r.Value.(listPermissionsResult) + permissions = append(permissions, v.permissions...) + tMap[v.ofgaType] = v.token + + if v.err != nil { + errors = append(errors, v.err) + } + } + + if len(errors) == 0 { + return permissions, tMap, nil + } + + eMsg := "" + + for n, e := range errors { + s.logger.Errorf(e.Error()) + eMsg = fmt.Sprintf("%s%v - %s\n", eMsg, n, e.Error()) + } + + return permissions, tMap, fmt.Errorf(eMsg) +} + +func (s *OpenFGAStore) listPermissionsFunc(ctx context.Context, ID, ofgaType, cToken string) func() any { + return func() any { + p, token, err := s.listPermissionsByType( + ctx, + ID, + ofgaType, + cToken, + ) + + return listPermissionsResult{ + permissions: p, + ofgaType: ofgaType, + token: token, + err: err, + } + } +} + +func (s *OpenFGAStore) listPermissionsByType(ctx context.Context, ID, pType, continuationToken string) ([]Permission, string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.listPermissionsByType") + defer span.End() + + r, err := s.ofga.ReadTuples(ctx, ID, "", fmt.Sprintf("%s:", pType), continuationToken) + + if err != nil { + s.logger.Error(err.Error()) + return nil, "", err + } + + permissions := make([]Permission, 0) + + for _, t := range r.GetTuples() { + // if relation doesn't start with can_ it means it's not a permission (see #assignee) + if !strings.HasPrefix(t.Key.Relation, "can_") { + continue + } + + permissions = append(permissions, Permission{Relation: t.Key.Relation, Object: t.Key.Object}) + } + + return permissions, r.GetContinuationToken(), nil +} + +func (s *OpenFGAStore) permissionTypes() []string { + return []string{"group", "role", "identity", "scheme", "provider", "client"} +} + +// NewOpenFGAStore returns the implementation of the store +func NewOpenFGAStore(ofga OpenFGAClientInterface, wpool pool.WorkerPoolInterface, tracer trace.Tracer, monitor monitoring.MonitorInterface, logger logging.LoggerInterface) *OpenFGAStore { + s := new(OpenFGAStore) + + s.ofga = ofga + s.wpool = wpool + + s.monitor = monitor + s.tracer = tracer + s.logger = logger + + return s +} diff --git a/internal/openfga/stores_test.go b/internal/openfga/stores_test.go new file mode 100644 index 000000000..a5e0b2d90 --- /dev/null +++ b/internal/openfga/stores_test.go @@ -0,0 +1,963 @@ +// Copyright 2024 Canonical Ltd. +// SPDX-License-Identifier: AGPL + +package openfga + +import ( + "cmp" + "context" + "fmt" + "reflect" + "slices" + "strings" + sync "sync" + "testing" + "time" + + "go.opentelemetry.io/otel/trace" + "go.uber.org/mock/gomock" + + "github.com/google/uuid" + openfga "github.com/openfga/go-sdk" + "github.com/openfga/go-sdk/client" + + "github.com/canonical/identity-platform-admin-ui/internal/monitoring" + pool "github.com/canonical/identity-platform-admin-ui/internal/pool" +) + +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_logger.go -source=../../internal/logging/interfaces.go +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_client.go -source=./interfaces.go +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_openfga_client.go github.com/openfga/go-sdk/client SdkClientListObjectsRequestInterface,SdkClientReadRequestInterface,SdkClientWriteRequestInterface,SdkClientBatchCheckRequestInterface +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_monitor.go -source=../../internal/monitoring/interfaces.go +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_pool.go -source=../../internal/pool/interfaces.go +//go:generate mockgen -build_flags=--mod=mod -package openfga -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer + +func setupMockSubmit(wp *MockWorkerPoolInterface, resultsChan chan *pool.Result[any]) (*gomock.Call, chan *pool.Result[any]) { + key := uuid.New() + var internalResultsChannel chan *pool.Result[any] + + call := wp.EXPECT().Submit(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Do( + func(command any, results chan *pool.Result[any], wg *sync.WaitGroup) { + var value any = true + + switch commandFunc := command.(type) { + case func(): + commandFunc() + case func() any: + value = commandFunc() + } + + result := pool.NewResult[any](key, value) + results <- result + if resultsChan != nil { + resultsChan <- result + } + + wg.Done() + + internalResultsChannel = results + }, + ).Return(key.String(), nil) + + return call, internalResultsChannel +} + +func TestStoreListViewableRoles(t *testing.T) { + type expected struct { + err error + roles []string + } + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "empty result", + input: "user:joe", + expected: expected{ + roles: []string{}, + err: nil, + }, + }, + { + name: "error", + input: "role:administrator#assignee", + expected: expected{ + roles: []string{}, + err: fmt.Errorf("error"), + }, + }, + { + name: "full result", + input: "group:is#member", + expected: expected{ + roles: []string{"global", "administrator", "viewer"}, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().ListObjects(gomock.Any(), test.input, "can_view", "role").Return(test.expected.roles, test.expected.err) + + if test.expected.err != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + roles, err := store.ListViewableRoles(context.Background(), test.input) + + if err != test.expected.err { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err == nil && !reflect.DeepEqual(roles, test.expected.roles) { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.roles, roles) + } + }) + } +} + +func TestStoreListAssignedRoles(t *testing.T) { + type expected struct { + err error + roles []string + } + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "empty result", + input: "user:joe", + expected: expected{ + roles: []string{}, + err: nil, + }, + }, + { + name: "error", + input: "user:joe", + expected: expected{ + roles: []string{}, + err: fmt.Errorf("error"), + }, + }, + { + name: "full result", + input: "group:is#member", + expected: expected{ + roles: []string{"global", "administrator", "viewer"}, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().ListObjects(gomock.Any(), test.input, ASSIGNEE_RELATION, "role").Return(test.expected.roles, test.expected.err) + + if test.expected.err != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + roles, err := store.ListAssignedRoles(context.Background(), test.input) + + if err != test.expected.err { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err == nil && !reflect.DeepEqual(roles, test.expected.roles) { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.roles, roles) + } + }) + } +} + +func TestStoreListAssignedGroups(t *testing.T) { + type expected struct { + err error + groups []string + } + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "empty result", + input: "user:joe", + expected: expected{ + groups: []string{}, + err: nil, + }, + }, + { + name: "error", + input: "user:joe", + expected: expected{ + groups: []string{}, + err: fmt.Errorf("error"), + }, + }, + { + name: "full result", + input: "group:is#member", + expected: expected{ + groups: []string{"global", "administrator", "viewer"}, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().ListObjects(gomock.Any(), test.input, MEMBER_RELATION, "group").Return(test.expected.groups, test.expected.err) + + if test.expected.err != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + groups, err := store.ListAssignedGroups(context.Background(), test.input) + + if err != test.expected.err { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err == nil && !reflect.DeepEqual(groups, test.expected.groups) { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.groups, groups) + } + }) + } +} + +func TestStoreAssignRoles(t *testing.T) { + type input struct { + assignee string + roles []string + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "group:administrator#member", + roles: []string{"role:viewer"}, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple roles to group members", + input: input{ + assignee: "group:administrator#member", + roles: []string{"role:viewer", "role:writer", "role:super"}, + }, + expected: nil, + }, + { + name: "multiple roles to a user", + input: input{ + assignee: "user:joe", + roles: []string{"role:viewer", "role:writer", "role:super"}, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().WriteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + roles := make([]Tuple, 0) + + for _, role := range test.input.roles { + roles = append(roles, *NewTuple(test.input.assignee, ASSIGNEE_RELATION, role)) + } + + if !reflect.DeepEqual(roles, tuples) { + t.Errorf("expected tuples to be %v got %v", roles, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.AssignRoles(context.Background(), test.input.assignee, test.input.roles...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreUnassignRoles(t *testing.T) { + type input struct { + assignee string + roles []string + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "group:administrator#member", + roles: []string{"role:viewer"}, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple roles to group members", + input: input{ + assignee: "group:administrator#member", + roles: []string{"role:viewer", "role:writer", "role:super"}, + }, + expected: nil, + }, + { + name: "multiple roles to a user", + input: input{ + assignee: "user:joe", + roles: []string{"role:viewer", "role:writer", "role:super"}, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().DeleteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + roles := make([]Tuple, 0) + + for _, role := range test.input.roles { + roles = append(roles, *NewTuple(test.input.assignee, ASSIGNEE_RELATION, role)) + } + + if !reflect.DeepEqual(roles, tuples) { + t.Errorf("expected tuples to be %v got %v", roles, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.UnassignRoles(context.Background(), test.input.assignee, test.input.roles...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreAssignGroups(t *testing.T) { + type input struct { + assignee string + groups []string + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "administrator", + groups: []string{"group:viewer"}, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple groups to group members", + input: input{ + assignee: "group:administrator#member", + groups: []string{"group:viewer", "group:writer", "group:super"}, + }, + expected: nil, + }, + { + name: "multiple groups to a user", + input: input{ + assignee: "user:joe", + groups: []string{"group:viewer", "group:writer", "group:super"}, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().WriteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + groups := make([]Tuple, 0) + + for _, group := range test.input.groups { + groups = append(groups, *NewTuple(test.input.assignee, MEMBER_RELATION, group)) + } + + if !reflect.DeepEqual(groups, tuples) { + t.Errorf("expected tuples to be %v got %v", groups, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.AssignGroups(context.Background(), test.input.assignee, test.input.groups...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreUnassignGroups(t *testing.T) { + type input struct { + assignee string + groups []string + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "administrator", + groups: []string{"group:viewer"}, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple groups to group members", + input: input{ + assignee: "group:administrator#member", + groups: []string{"group:viewer", "group:writer", "group:super"}, + }, + expected: nil, + }, + { + name: "multiple groups to a user", + input: input{ + assignee: "user:joe", + groups: []string{"group:viewer", "group:writer", "group:super"}, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().DeleteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + groups := make([]Tuple, 0) + + for _, group := range test.input.groups { + groups = append(groups, *NewTuple(test.input.assignee, MEMBER_RELATION, group)) + } + + if !reflect.DeepEqual(groups, tuples) { + t.Errorf("expected tuples to be %v got %v", groups, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.UnassignGroups(context.Background(), test.input.assignee, test.input.groups...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreAssignPermissions(t *testing.T) { + type input struct { + assignee string + permissions []Permission + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "role:administrator#assignee", + permissions: []Permission{ + {Relation: "can_delete", Object: "role:admin"}, + }, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple permissions to role", + input: input{ + assignee: "role:administrator#assignee", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + { + name: "multiple permissions to group", + input: input{ + assignee: "group:administrator#member", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + { + name: "multiple permissions to user", + input: input{ + assignee: "user:joe", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().WriteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + ps := make([]Tuple, 0) + + for _, p := range test.input.permissions { + ps = append(ps, *NewTuple(test.input.assignee, p.Relation, p.Object)) + } + + if !reflect.DeepEqual(ps, tuples) { + t.Errorf("expected tuples to be %v got %v", ps, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.AssignPermissions(context.Background(), test.input.assignee, test.input.permissions...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreUnassignPermissions(t *testing.T) { + type input struct { + assignee string + permissions []Permission + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + assignee: "role:administrator#assignee", + permissions: []Permission{ + {Relation: "can_delete", Object: "role:admin"}, + }, + }, + expected: fmt.Errorf("error"), + }, + { + name: "multiple permissions to role", + input: input{ + assignee: "role:administrator#assignee", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + { + name: "multiple permissions to group", + input: input{ + assignee: "group:administrator#member", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + { + name: "multiple permissions to user", + input: input{ + assignee: "user:joe", + permissions: []Permission{ + {Relation: "can_view", Object: "client:okta"}, + {Relation: "can_edit", Object: "client:okta"}, + {Relation: "can_delete", Object: "group:admin"}, + }, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockOpenFGA.EXPECT().DeleteTuples(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, tuples ...Tuple) error { + ps := make([]Tuple, 0) + + for _, p := range test.input.permissions { + ps = append(ps, *NewTuple(test.input.assignee, p.Relation, p.Object)) + } + + if !reflect.DeepEqual(ps, tuples) { + t.Errorf("expected tuples to be %v got %v", ps, tuples) + } + + return test.expected + }, + ) + + if test.expected != nil { + mockLogger.EXPECT().Error(gomock.Any()).Times(1) + } + + err := store.UnassignPermissions(context.Background(), test.input.assignee, test.input.permissions...) + + if err != test.expected { + t.Errorf("expected error to be %v got %v", test.expected, err) + } + }) + } +} + +func TestStoreListPermissions(t *testing.T) { + type input struct { + ID string + cTokens map[string]string + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + ID: "role:administrator#assignee", + }, + expected: fmt.Errorf("error"), + }, + { + name: "role found", + input: input{ + ID: "role:administrator#assignee", + cTokens: map[string]string{ + "role": "test", + }, + }, + expected: nil, + }, + { + name: "group found", + input: input{ + ID: "group:administrator#member", + cTokens: map[string]string{ + "role": "test", + }, + }, + expected: nil, + }, + { + name: "user found", + input: input{ + ID: "use:joe", + cTokens: map[string]string{ + "role": "test", + }, + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + for i := 0; i < 6; i++ { + setupMockSubmit(mockWorkerPool, nil) + } + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + + expCTokens := map[string]string{ + "role": "", + "group": "", + "identity": "", + "scheme": "", + "provider": "", + "client": "", + } + + expPermissions := []Permission{ + Permission{Relation: "can_edit", Object: "role:test"}, + Permission{Relation: "can_edit", Object: "group:test"}, + Permission{Relation: "can_edit", Object: "identity:test"}, + Permission{Relation: "can_edit", Object: "scheme:test"}, + Permission{Relation: "can_edit", Object: "provider:test"}, + Permission{Relation: "can_edit", Object: "client:test"}, + } + + calls := []*gomock.Call{} + + for _, _ = range store.permissionTypes() { + calls = append( + calls, + mockOpenFGA.EXPECT().ReadTuples(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, user, relation, object, continuationToken string) (*client.ClientReadResponse, error) { + if test.expected != nil { + return nil, test.expected + } + + if user != test.input.ID { + t.Errorf("wrong user parameter expected %s got %s", test.input.ID, user) + } + + if object == "role:" && continuationToken != "test" { + t.Errorf("missing continuation token %s", test.input.cTokens["roles"]) + } + + tuples := []openfga.Tuple{ + *openfga.NewTuple( + *openfga.NewTupleKey( + user, "can_edit", fmt.Sprintf("%stest", object), + ), + time.Now(), + ), + *openfga.NewTuple( + *openfga.NewTupleKey( + user, "assignee", "role:test", + ), + time.Now(), + ), + } + + r := new(client.ClientReadResponse) + r.SetContinuationToken("") + r.SetTuples(tuples) + + return r, nil + }, + ), + ) + } + + if test.expected != nil { + // TODO @shipperizer fix this so that we can pin it down to the error case only + mockLogger.EXPECT().Error(gomock.Any()).MinTimes(0).MaxTimes(12) + mockLogger.EXPECT().Errorf(gomock.Any()).MaxTimes(12) + } + + gomock.InAnyOrder(calls) + permissions, cTokens, err := store.ListPermissions(context.Background(), test.input.ID, test.input.cTokens) + + if err != nil && test.expected == nil { + t.Fatalf("expected error to be silenced and return nil got %v instead", err) + } + + sortFx := func(a, b Permission) int { + if n := strings.Compare(a.Relation, b.Relation); n != 0 { + return n + } + // If relations are equal, order by object + return cmp.Compare(a.Object, b.Object) + } + + slices.SortFunc(permissions, sortFx) + slices.SortFunc(expPermissions, sortFx) + + if err == nil && test.expected == nil && !reflect.DeepEqual(permissions, expPermissions) { + t.Fatalf("expected permissions to be %v got %v", expPermissions, permissions) + } + + if err == nil && test.expected == nil && !reflect.DeepEqual(cTokens, expCTokens) { + t.Fatalf("expected continuation tokens to be %v got %v", expCTokens, cTokens) + } + }) + } +} diff --git a/internal/openfga/types.go b/internal/openfga/types.go index 4a1ab40ad..9cb1873f1 100644 --- a/internal/openfga/types.go +++ b/internal/openfga/types.go @@ -3,6 +3,18 @@ package openfga +type listPermissionsResult struct { + permissions []Permission + token string + ofgaType string + err error +} + +type Permission struct { + Relation string `json:"relation" validate:"required"` + Object string `json:"object" validate:"required"` +} + // Tuple is simply a wrapper around openfga TupleKey // reason to have it is to hide underlying library complexity // in case we want to swap it diff --git a/pkg/identities/interfaces.go b/pkg/identities/interfaces.go index c10b53b90..5a4df59ec 100644 --- a/pkg/identities/interfaces.go +++ b/pkg/identities/interfaces.go @@ -7,6 +7,8 @@ import ( "context" kClient "github.com/ory/kratos-client-go" + + ofga "github.com/canonical/identity-platform-admin-ui/internal/openfga" ) type AuthorizerInterface interface { @@ -21,3 +23,15 @@ type ServiceInterface interface { UpdateIdentity(context.Context, string, *kClient.UpdateIdentityBody) (*IdentityData, error) DeleteIdentity(context.Context, string) (*IdentityData, error) } + +type OpenFGAStoreInterface interface { + ListAssignedRoles(context.Context, string) ([]string, error) + ListAssignedGroups(context.Context, string) ([]string, error) + AssignRoles(context.Context, string, ...string) error + UnassignRoles(context.Context, string, ...string) error + AssignGroups(context.Context, string, ...string) error + UnassignGroups(context.Context, string, ...string) error + ListPermissions(context.Context, string, map[string]string) ([]ofga.Permission, map[string]string, error) + AssignPermissions(context.Context, string, ...ofga.Permission) error + UnassignPermissions(context.Context, string, ...ofga.Permission) error +} diff --git a/pkg/identities/service.go b/pkg/identities/service.go index df01f4cf9..22cb723cc 100644 --- a/pkg/identities/service.go +++ b/pkg/identities/service.go @@ -9,15 +9,24 @@ import ( "fmt" "io" "net/http" + "strings" + v1 "github.com/canonical/rebac-admin-ui-handlers/v1" + "github.com/canonical/rebac-admin-ui-handlers/v1/resources" kClient "github.com/ory/kratos-client-go" "go.opentelemetry.io/otel/trace" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + coreV1 "k8s.io/client-go/kubernetes/typed/core/v1" "github.com/canonical/identity-platform-admin-ui/internal/http/types" "github.com/canonical/identity-platform-admin-ui/internal/logging" "github.com/canonical/identity-platform-admin-ui/internal/monitoring" + ofga "github.com/canonical/identity-platform-admin-ui/internal/openfga" ) +// TODO @shipperizer unify this value with schemas/service.go +const DEFAULT_SCHEMA = "default.schema" + type Service struct { kratos kClient.IdentityAPI authz AuthorizerInterface @@ -63,7 +72,7 @@ func (s *Service) parseError(r *http.Response) *kClient.GenericError { } func (s *Service) ListIdentities(ctx context.Context, size int64, token, credID string) (*IdentityData, error) { - ctx, span := s.tracer.Start(ctx, "kratos.IdentityAPI.ListIdentities") + ctx, span := s.tracer.Start(ctx, "identities.Service.ListIdentities") defer span.End() identities, rr, err := s.kratos.ListIdentitiesExecute( @@ -94,7 +103,7 @@ func (s *Service) ListIdentities(ctx context.Context, size int64, token, credID } func (s *Service) GetIdentity(ctx context.Context, ID string) (*IdentityData, error) { - ctx, span := s.tracer.Start(ctx, "kratos.IdentityAPI.GetIdentity") + ctx, span := s.tracer.Start(ctx, "identities.Service.GetIdentity") defer span.End() identity, rr, err := s.kratos.GetIdentityExecute( @@ -118,7 +127,7 @@ func (s *Service) GetIdentity(ctx context.Context, ID string) (*IdentityData, er } func (s *Service) CreateIdentity(ctx context.Context, bodyID *kClient.CreateIdentityBody) (*IdentityData, error) { - ctx, span := s.tracer.Start(ctx, "kratos.IdentityAPI.CreateIdentity") + ctx, span := s.tracer.Start(ctx, "identities.Service.CreateIdentity") defer span.End() if bodyID == nil { @@ -158,7 +167,7 @@ func (s *Service) CreateIdentity(ctx context.Context, bodyID *kClient.CreateIden } func (s *Service) UpdateIdentity(ctx context.Context, ID string, bodyID *kClient.UpdateIdentityBody) (*IdentityData, error) { - ctx, span := s.tracer.Start(ctx, "kratos.IdentityAPI.UpdateIdentity") + ctx, span := s.tracer.Start(ctx, "identities.Service.UpdateIdentity") defer span.End() if ID == "" { err := fmt.Errorf("no identity ID passed") @@ -207,7 +216,7 @@ func (s *Service) UpdateIdentity(ctx context.Context, ID string, bodyID *kClient } func (s *Service) DeleteIdentity(ctx context.Context, ID string) (*IdentityData, error) { - ctx, span := s.tracer.Start(ctx, "kratos.IdentityAPI.DeleteIdentity") + ctx, span := s.tracer.Start(ctx, "identities.Service.DeleteIdentity") defer span.End() rr, err := s.kratos.DeleteIdentityExecute( @@ -240,3 +249,514 @@ func NewService(kratos kClient.IdentityAPI, authz AuthorizerInterface, tracer tr return s } + +type V1Service struct { + cmName string + cmNamespace string + + k8s coreV1.CoreV1Interface + openfgaStore OpenFGAStoreInterface + + core *Service +} + +func (s *V1Service) getDefaultSchema(ctx context.Context) (string, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.getDefaultSchema") + defer span.End() + + cm, err := s.k8s.ConfigMaps(s.cmNamespace).Get(ctx, s.cmName, metaV1.GetOptions{}) + + if err != nil { + s.core.logger.Error(err.Error()) + return "", err + } + + ID, ok := cm.Data[DEFAULT_SCHEMA] + + if !ok { + return "", fmt.Errorf("missing default schema") + } + + return ID, nil +} + +// ListIdentities returns a page of Identity objects of at least `size` elements if available +func (s *V1Service) ListIdentities(ctx context.Context, params *resources.GetIdentitiesParams) (*resources.PaginatedResponse[resources.Identity], error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.ListIdentities") + defer span.End() + + size := 100 + token := "" + + if params != nil && params.Size != nil { + size = *params.Size + } + + if params != nil && params.NextToken != nil { + token = *params.NextToken + } + + // TODO @shipperizer use params.Filter to fetch credID + ids, err := s.core.ListIdentities(ctx, int64(size), token, "") + + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + r := new(resources.PaginatedResponse[resources.Identity]) + r.Data = make([]resources.Identity, 0) + r.Meta = resources.ResponseMeta{Size: len(ids.Identities), PageToken: &token} + r.Next = resources.Next{PageToken: &ids.Tokens.Next} + for _, id := range ids.Identities { + traits, ok := id.Traits.(map[string]string) + + if !ok { + traits = make(map[string]string) + } + + // TODO @shipperizer enhance Identity resource with Permissions and Roles on the next iteration + // this requires calls to openfga in here unless we enhance the PrincipalContext and let that do + // the calls + i := resources.Identity{ + Id: &id.Id, + } + + if email, ok := traits["email"]; ok { + i.Email = email + } + + fullname, ok := traits["name"] + + if !ok { + r.Data = append(r.Data, i) + continue + } + + surnameIndex := strings.LastIndex(fullname, " ") + + if surnameIndex > 0 { + name := strings.Trim(fullname[0:surnameIndex], " ") + surname := strings.Trim(fullname[surnameIndex:], " ") + + i.FirstName = &name + i.LastName = &surname + } + + r.Data = append(r.Data, i) + } + + return r, nil +} + +// CreateIdentity creates a single Identity. +func (s *V1Service) CreateIdentity(ctx context.Context, identity *resources.Identity) (*resources.Identity, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.CreateIdentity") + defer span.End() + + active := "StateActive" + schemaId, err := s.getDefaultSchema(ctx) + + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + if identity == nil { + return nil, v1.NewRequestBodyValidationError("bad identity payload") + } + + traits := make(map[string]interface{}) + + traits["email"] = identity.Email + + if identity.FirstName != nil && identity.LastName != nil { + traits["name"] = fmt.Sprintf("%s %s", *identity.FirstName, *identity.LastName) + } + + ids, err := s.core.CreateIdentity(ctx, + &kClient.CreateIdentityBody{ + State: &active, + SchemaId: schemaId, + // TODO @shipperizer the code below assumes each schema has name and email + // needs to be validated as schemas might differ + Traits: traits, + }, + ) + + // TODO @shipperizer enhance Identity resource with Permissions and Roles on the next iteration + // this requires calls to openfga in here unless we enhance the PrincipalContext and let that do + // the calls + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + if len(ids.Identities) != 1 { + return nil, v1.NewInvalidRequestError("no identity created") + } + + return &resources.Identity{ + Email: identity.Email, + FirstName: identity.FirstName, + LastName: identity.LastName, + Id: &ids.Identities[0].Id, + }, nil +} + +// GetIdentity returns a single Identity. +func (s *V1Service) GetIdentity(ctx context.Context, identityId string) (*resources.Identity, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.GetIdentity") + defer span.End() + + ids, err := s.core.GetIdentity(ctx, identityId) + + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + if ids.Identities == nil || len(ids.Identities) != 1 { + return nil, v1.NewNotFoundError("identity not found") + } + + id := ids.Identities[0] + + traits, ok := id.Traits.(map[string]string) + + if !ok { + traits = make(map[string]string) + } + + // TODO @shipperizer enhance Identity resource with Permissions and Roles on the next iteration + // this requires calls to openfga in here unless we enhance the PrincipalContext and let that do + // the calls + i := new(resources.Identity) + + i.Id = &id.Id + + if email, ok := traits["email"]; ok { + i.Email = email + } + + fullname, ok := traits["name"] + if !ok { + return i, nil + } + + surnameIndex := strings.LastIndex(fullname, " ") + + if surnameIndex > 0 { + name := strings.Trim(fullname[0:surnameIndex], " ") + surname := strings.Trim(fullname[surnameIndex:], " ") + + i.FirstName = &name + i.LastName = &surname + } + + return i, nil +} + +// UpdateIdentity updates an Identity. +func (s *V1Service) UpdateIdentity(ctx context.Context, identity *resources.Identity) (*resources.Identity, error) { + _, span := s.core.tracer.Start(ctx, "identities.V1Service.UpdateIdentity") + defer span.End() + + if identity == nil { + return nil, v1.NewRequestBodyValidationError("bad identity payload") + } + + traits := make(map[string]interface{}) + + traits["email"] = identity.Email + if identity.FirstName != nil && identity.LastName != nil { + traits["name"] = fmt.Sprintf("%s %s", *identity.FirstName, *identity.LastName) + } + + body := kClient.NewUpdateIdentityBodyWithDefaults() + body.SetTraits(traits) + + ids, err := s.core.UpdateIdentity( + ctx, + *identity.Id, + // TODO @shipperizer the code below assumes each schema has name and email + // needs to be validated as schemas might differ + body, + ) + + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + if len(ids.Identities) != 1 { + return nil, v1.NewInvalidRequestError("no identity created") + } + + id := ids.Identities[0] + + ts, ok := id.GetTraits().(map[string]string) + + if !ok { + ts = make(map[string]string) + } + + // TODO @shipperizer enhance Identity resource with Permissions and Roles on the next iteration + // this requires calls to openfga in here unless we enhance the PrincipalContext and let that do + // the calls + i := new(resources.Identity) + + i.Id = &id.Id + + if email, ok := ts["email"]; ok { + i.Email = email + } + + fullname, ok := ts["name"] + if !ok { + return i, nil + } + + surnameIndex := strings.LastIndex(fullname, " ") + + if surnameIndex > 0 { + name := strings.Trim(fullname[0:surnameIndex], " ") + surname := strings.Trim(fullname[surnameIndex:], " ") + + i.FirstName = &name + i.LastName = &surname + } + + return i, nil + +} + +// DeleteIdentity deletes an Identity +// returns (true, nil) in case an identity was successfully delete +// return (false, error) in case something went wrong +// implementors may want to return (false, nil) for idempotency cases +func (s *V1Service) DeleteIdentity(ctx context.Context, identityId string) (bool, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.DeleteIdentity") + defer span.End() + + if _, err := s.core.DeleteIdentity(ctx, identityId); err != nil { + return false, v1.NewUnknownError(err.Error()) + } + + return true, nil +} + +// GetIdentityGroups returns a page of Groups for identity `identityId`. +func (s *V1Service) GetIdentityGroups(ctx context.Context, identityId string, params *resources.GetIdentitiesItemGroupsParams) (*resources.PaginatedResponse[resources.Group], error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.GetIdentityGroups") + defer span.End() + + groups, err := s.openfgaStore.ListAssignedGroups(ctx, fmt.Sprintf("user:%s", identityId)) + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + r := new(resources.PaginatedResponse[resources.Group]) + r.Data = make([]resources.Group, 0) + r.Meta = resources.ResponseMeta{Size: len(groups)} + + for _, group := range groups { + r.Data = append(r.Data, resources.Group{Id: &group, Name: group}) + } + + return r, nil +} + +// GetIdentityRoles returns a page of Roles for identity `identityId`. +func (s *V1Service) GetIdentityRoles(ctx context.Context, identityId string, params *resources.GetIdentitiesItemRolesParams) (*resources.PaginatedResponse[resources.Role], error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.GetIdentityRoles") + defer span.End() + + roles, err := s.openfgaStore.ListAssignedRoles(ctx, fmt.Sprintf("user:%s", identityId)) + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + r := new(resources.PaginatedResponse[resources.Role]) + r.Data = make([]resources.Role, 0) + r.Meta = resources.ResponseMeta{Size: len(roles)} + + for _, role := range roles { + r.Data = append(r.Data, resources.Role{Id: &role, Name: role}) + } + + return r, nil +} + +// PatchIdentityGroups performs addition or removal of Groups to/from an Identity. +func (s *V1Service) PatchIdentityGroups(ctx context.Context, identityId string, groupPatches []resources.IdentityGroupsPatchItem) (bool, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.PatchIdentityGroups") + defer span.End() + + additions := make([]string, 0) + removals := make([]string, 0) + for _, p := range groupPatches { + group := fmt.Sprintf("group:%s", p.Group) + + if p.Op == "add" { + additions = append(additions, group) + } else if p.Op == "remove" { + removals = append(removals, group) + } + } + + if len(additions) > 0 { + err := s.openfgaStore.AssignGroups(ctx, fmt.Sprintf("user:%s", identityId), additions...) + + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + if len(removals) > 0 { + err := s.openfgaStore.UnassignGroups(ctx, fmt.Sprintf("user:%s", identityId), removals...) + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + return true, nil +} + +// PatchIdentityRoles performs addition or removal of Roles to/from an Identity. +func (s *V1Service) PatchIdentityRoles(ctx context.Context, identityId string, rolePatches []resources.IdentityRolesPatchItem) (bool, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.PatchIdentityRoles") + defer span.End() + + additions := make([]string, 0) + removals := make([]string, 0) + for _, p := range rolePatches { + role := fmt.Sprintf("role:%s", p.Role) + + if p.Op == "add" { + additions = append(additions, role) + } else if p.Op == "remove" { + removals = append(removals, role) + } + } + + if len(additions) > 0 { + err := s.openfgaStore.AssignRoles(ctx, fmt.Sprintf("user:%s", identityId), additions...) + + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + if len(removals) > 0 { + err := s.openfgaStore.UnassignRoles(ctx, fmt.Sprintf("user:%s", identityId), removals...) + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + return true, nil +} + +// GetIdentityEntitlements returns a page of Entitlements for identity `identityId`. +func (s *V1Service) GetIdentityEntitlements(ctx context.Context, identityId string, params *resources.GetIdentitiesItemEntitlementsParams) (*resources.PaginatedResponse[resources.EntityEntitlement], error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.GetIdentityEntitlements") + defer span.End() + + paginator := types.NewTokenPaginator(s.core.tracer, s.core.logger) + + nextToken := "" + + if params != nil && params.NextPageToken != nil { + nextToken = *params.NextPageToken + } + + if err := paginator.LoadFromString(ctx, nextToken); err != nil { + s.core.logger.Error(err) + } + + permissions, pageTokens, err := s.openfgaStore.ListPermissions(ctx, fmt.Sprintf("user:%s", identityId), paginator.GetAllTokens(ctx)) + + if err != nil { + return nil, v1.NewUnknownError(err.Error()) + } + + paginator.SetTokens(ctx, pageTokens) + metaParam, err := paginator.PaginationHeader(ctx) + if err != nil { + s.core.logger.Errorf("error producing pagination meta param: %s", err) + metaParam = "" + } + + r := new(resources.PaginatedResponse[resources.EntityEntitlement]) + r.Meta = resources.ResponseMeta{Size: len(permissions)} + r.Data = make([]resources.EntityEntitlement, 0) + r.Next.PageToken = &metaParam + + for _, permission := range permissions { + + entity := strings.SplitN(permission.Object, ":", 2) + r.Data = append( + r.Data, + resources.EntityEntitlement{ + Entitlement: permission.Relation, + EntityType: entity[0], + EntityId: entity[1], + }, + ) + } + + return r, nil +} + +// PatchIdentityEntitlements performs addition or removal of an Entitlement to/from an Identity. +func (s *V1Service) PatchIdentityEntitlements(ctx context.Context, identityId string, entitlementPatches []resources.IdentityEntitlementsPatchItem) (bool, error) { + ctx, span := s.core.tracer.Start(ctx, "identities.V1Service.PatchIdentityEntitlements") + defer span.End() + + additions := make([]ofga.Permission, 0) + removals := make([]ofga.Permission, 0) + for _, p := range entitlementPatches { + permission := ofga.Permission{ + Relation: p.Entitlement.Entitlement, + Object: fmt.Sprintf("%s:%s", p.Entitlement.EntityType, p.Entitlement.EntityId), + } + + if p.Op == "add" { + additions = append(additions, permission) + } else if p.Op == "remove" { + removals = append(removals, permission) + } + } + + if len(additions) > 0 { + err := s.openfgaStore.AssignPermissions(ctx, fmt.Sprintf("user:%s", identityId), additions...) + + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + if len(removals) > 0 { + err := s.openfgaStore.UnassignPermissions(ctx, fmt.Sprintf("user:%s", identityId), removals...) + if err != nil { + return false, v1.NewUnknownError(err.Error()) + } + } + + return true, nil +} + +type Config struct { + Name string + Namespace string + K8s coreV1.CoreV1Interface + OpenFGAStore OpenFGAStoreInterface +} + +func NewV1Service(config *Config, svc *Service) *V1Service { + s := new(V1Service) + + s.core = svc + s.k8s = config.K8s + s.cmName = config.Name + s.cmNamespace = config.Namespace + s.openfgaStore = config.OpenFGAStore + + return s +} diff --git a/pkg/identities/service_test.go b/pkg/identities/service_test.go index 915e7478a..f3e196ca6 100644 --- a/pkg/identities/service_test.go +++ b/pkg/identities/service_test.go @@ -12,14 +12,22 @@ import ( reflect "reflect" "testing" + v1 "github.com/canonical/rebac-admin-ui-handlers/v1" + "github.com/canonical/rebac-admin-ui-handlers/v1/interfaces" + "github.com/canonical/rebac-admin-ui-handlers/v1/resources" + "github.com/google/uuid" kClient "github.com/ory/kratos-client-go" "go.opentelemetry.io/otel/trace" gomock "go.uber.org/mock/gomock" + corev1 "k8s.io/api/core/v1" + + ofga "github.com/canonical/identity-platform-admin-ui/internal/openfga" ) //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_logger.go -source=../../internal/logging/interfaces.go //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_interfaces.go -source=./interfaces.go //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_monitor.go -source=../../internal/monitoring/interfaces.go +//go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_corev1.go k8s.io/client-go/kubernetes/typed/core/v1 CoreV1Interface,ConfigMapInterface //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_tracing.go go.opentelemetry.io/otel/trace Tracer //go:generate mockgen -build_flags=--mod=mod -package identities -destination ./mock_kratos.go github.com/ory/kratos-client-go IdentityAPI @@ -45,7 +53,7 @@ func TestListIdentitiesSuccess(t *testing.T) { identities = append(identities, *kClient.NewIdentity(fmt.Sprintf("test-%v", i), "test.json", "https://test.com/test.json", map[string]string{"name": "name"})) } - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.ListIdentities").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().ListIdentities(ctx).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().ListIdentitiesExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIListIdentitiesRequest) ([]kClient.Identity, *http.Response, error) { @@ -108,7 +116,7 @@ func TestListIdentitiesFails(t *testing.T) { identities := make([]kClient.Identity, 0) mockLogger.EXPECT().Error(gomock.Any()).Times(1) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.ListIdentities").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().ListIdentities(ctx).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().ListIdentitiesExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIListIdentitiesRequest) ([]kClient.Identity, *http.Response, error) { @@ -187,7 +195,7 @@ func TestGetIdentitySuccess(t *testing.T) { identity := kClient.NewIdentity(credID, "test.json", "https://test.com/test.json", map[string]string{"name": "name"}) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.GetIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().GetIdentity(ctx, credID).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().GetIdentityExecute(gomock.Any()).Times(1).Return(identity, new(http.Response), nil) @@ -219,7 +227,7 @@ func TestGetIdentityFails(t *testing.T) { } mockLogger.EXPECT().Error(gomock.Any()).Times(1) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.GetIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().GetIdentity(ctx, credID).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().GetIdentityExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIGetIdentityRequest) (*kClient.Identity, *http.Response, error) { @@ -286,7 +294,7 @@ func TestCreateIdentitySuccess(t *testing.T) { identityBody := kClient.NewCreateIdentityBody("test.json", map[string]interface{}{"name": "name"}) identityBody.SetCredentials(*credentials) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.CreateIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockAuthz.EXPECT().SetCreateIdentityEntitlements(gomock.Any(), identity.Id) mockKratosIdentityAPI.EXPECT().CreateIdentity(ctx).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().CreateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( @@ -333,7 +341,7 @@ func TestCreateIdentityFails(t *testing.T) { identityBody.SetCredentials(*credentials) mockLogger.EXPECT().Error(gomock.Any()).Times(1) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.CreateIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().CreateIdentity(ctx).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().CreateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPICreateIdentityRequest) (*kClient.Identity, *http.Response, error) { @@ -401,7 +409,7 @@ func TestUpdateIdentitySuccess(t *testing.T) { identityBody.SetTraits(map[string]interface{}{"name": "name"}) identityBody.SetCredentials(*credentials) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.UpdateIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().UpdateIdentity(ctx, identity.Id).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().UpdateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIUpdateIdentityRequest) (*kClient.Identity, *http.Response, error) { @@ -450,7 +458,7 @@ func TestUpdateIdentityFails(t *testing.T) { identityBody.SetCredentials(*credentials) mockLogger.EXPECT().Error(gomock.Any()).Times(1) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.UpdateIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().UpdateIdentity(ctx, credID).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().UpdateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIUpdateIdentityRequest) (*kClient.Identity, *http.Response, error) { @@ -513,7 +521,7 @@ func TestDeleteIdentitySuccess(t *testing.T) { ApiService: mockKratosIdentityAPI, } - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.DeleteIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockAuthz.EXPECT().SetDeleteIdentityEntitlements(gomock.Any(), credID) mockKratosIdentityAPI.EXPECT().DeleteIdentity(ctx, credID).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().DeleteIdentityExecute(gomock.Any()).Times(1).Return(new(http.Response), nil) @@ -547,7 +555,7 @@ func TestDeleteIdentityFails(t *testing.T) { } mockLogger.EXPECT().Error(gomock.Any()).Times(1) - mockTracer.EXPECT().Start(ctx, "kratos.IdentityAPI.DeleteIdentity").Times(1).Return(ctx, trace.SpanFromContext(ctx)) + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) mockKratosIdentityAPI.EXPECT().DeleteIdentity(ctx, credID).Times(1).Return(identityRequest) mockKratosIdentityAPI.EXPECT().DeleteIdentityExecute(gomock.Any()).Times(1).DoAndReturn( func(r kClient.IdentityAPIDeleteIdentityRequest) (*http.Response, error) { @@ -592,3 +600,1732 @@ func TestDeleteIdentityFails(t *testing.T) { t.Fatal("expected error to be not nil") } } + +func TestV1ServiceImplementsRebacServiceInterface(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var svc interface{} = new(V1Service) + + if _, ok := svc.(interfaces.IdentitiesService); !ok { + t.Fatalf("V1Service doesnt implement interfaces.IdentitiesService") + } +} + +func TestV1ServiceListIdentities(t *testing.T) { + type input struct { + size int + token string + } + + type expected struct { + err error + identities []resources.Identity + } + + kIdentities := make([]kClient.Identity, 0) + identities := make([]resources.Identity, 0) + + for i := 0; i < 10; i++ { + id := uuid.NewString() + name := "Test User" + surname := fmt.Sprintf("%v", i) + email := fmt.Sprintf("test%v@gmail.com", i) + identities = append( + identities, + resources.Identity{ + Id: &id, + Email: email, + FirstName: &name, + LastName: &surname, + }, + ) + kIdentities = append( + kIdentities, + *kClient.NewIdentity( + id, + "test.json", + "https://test.com/test.json", + map[string]string{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ), + ) + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "empty result", + expected: expected{ + identities: []resources.Identity{}, + err: nil, + }, + }, + { + name: "error", + expected: expected{ + identities: nil, + err: fmt.Errorf("Internal Server Error: error"), + }, + }, + { + name: "full result", + input: input{ + size: 1000, + token: "eyJvZmZzZXQiOiIyNTAiLCJ2IjoyfQ", + }, + expected: expected{ + identities: identities, + err: nil, + }, + }, + { + name: "paginated result", + input: input{ + size: 2, + }, + expected: expected{ + identities: identities[:2], + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + identityRequest := kClient.IdentityAPIListIdentitiesRequest{ + ApiService: mockKratosIdentityAPI, + } + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockKratosIdentityAPI.EXPECT().ListIdentities(ctx).Times(1).Return(identityRequest) + mockKratosIdentityAPI.EXPECT().ListIdentitiesExecute(gomock.Any()).Times(1).DoAndReturn( + func(r kClient.IdentityAPIListIdentitiesRequest) ([]kClient.Identity, *http.Response, error) { + + // use reflect as attributes are private, also are pointers so need to cast it multiple times + if pageToken := (*string)(reflect.ValueOf(r).FieldByName("pageToken").UnsafePointer()); *pageToken != test.input.token { + t.Errorf("expected pageToken as %s, got %v", test.input.token, *pageToken) + } + + pageSize := (*int64)(reflect.ValueOf(r).FieldByName("pageSize").UnsafePointer()) + if *pageSize != int64(test.input.size) { + t.Errorf("expected page size as %v, got %v", test.input.size, *pageSize) + } + + if credID := (*string)(reflect.ValueOf(r).FieldByName("credentialsIdentifier").UnsafePointer()); credID != nil { + t.Errorf("expected credential id to be empty, got %v", *credID) + } + + if test.expected.err != nil { + rr := httptest.NewRecorder() + rr.Header().Set("Content-Type", "application/json") + rr.WriteHeader(http.StatusInternalServerError) + + json.NewEncoder(rr).Encode( + map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusInternalServerError, + "debug": "--------", + "details": map[string]interface{}{}, + "id": "string", + "message": "error", + "reason": "error", + "request": "d7ef54b1-ec15-46e6-bccb-524b82c035e6", + "status": "Not Found", + }, + }, + ) + + return []kClient.Identity{}, rr.Result(), fmt.Errorf("error") + } + + rr := new(http.Response) + rr.Header = make(http.Header) + rr.Header.Set("Link", `; rel="first",; rel="next",; rel="prev`) + + if int64(len(kIdentities)) > *pageSize { + return kIdentities[:*pageSize], rr, nil + } + + return kIdentities, rr, nil + + }, + ) + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + r, err := svc.ListIdentities( + ctx, + &resources.GetIdentitiesParams{ + Size: &test.input.size, + NextToken: &test.input.token, + }, + ) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v not %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + for n, i := range r.Data { + if i.Email != test.expected.identities[n].Email { + t.Errorf("expected identities to be %s not %s", test.expected.identities[n].Email, i.Email) + } + + if *i.FirstName != *test.expected.identities[n].FirstName { + t.Errorf("expected name to be %s not %s", *test.expected.identities[n].FirstName, *i.FirstName) + } + + if *i.LastName != *test.expected.identities[n].LastName { + t.Errorf("expected surname to be %s not %s", *test.expected.identities[n].LastName, *i.LastName) + } + } + + if len(r.Data) > 0 && test.input.size > 0 && *r.Next.PageToken != "eyJvZmZzZXQiOiIyNTAiLCJ2IjoyfQ" { + t.Errorf("expected token to be eyJvZmZzZXQiOiIyNTAiLCJ2IjoyfQ, not %s", *r.Next.PageToken) + } + + }) + } +} + +func TestV1ServiceCreateIdentity(t *testing.T) { + type input struct { + identity *resources.Identity + } + + type expected struct { + err error + identity *resources.Identity + } + + id := uuid.NewString() + name := "Test" + surname := "User" + email := "test@gmail.com" + kIdentity := kClient.NewIdentity( + id, + "test", + "https://test.com/test.json", + map[string]string{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ) + identity := resources.Identity{ + Email: email, + FirstName: &name, + LastName: &surname, + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "error", + input: input{ + identity: &identity, + }, + expected: expected{ + + err: v1.NewRequestBodyValidationError("bad identity payload"), + }, + }, + { + name: "success", + input: input{ + identity: &identity, + }, + expected: expected{ + identity: &identity, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockConfigMapV1 := NewMockConfigMapInterface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + ctx := context.Background() + + identityRequest := kClient.IdentityAPICreateIdentityRequest{ + ApiService: mockKratosIdentityAPI, + } + + identityBody := kClient.NewCreateIdentityBody( + kIdentity.SchemaId, + map[string]interface{}{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ) + identityBody.SetState("StateActive") + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockAuthz.EXPECT().SetCreateIdentityEntitlements(gomock.Any(), id).MinTimes(0).MaxTimes(1) + mockCoreV1.EXPECT().ConfigMaps(cfg.Namespace).MinTimes(0).MaxTimes(1).Return(mockConfigMapV1) + mockConfigMapV1.EXPECT().Get(ctx, cfg.Name, gomock.Any()).MinTimes(0).MaxTimes(1).Return(cm, nil) + + mockKratosIdentityAPI.EXPECT().CreateIdentity(gomock.Any()).Times(1).Return(identityRequest) + mockKratosIdentityAPI.EXPECT().CreateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( + func(r kClient.IdentityAPICreateIdentityRequest) (*kClient.Identity, *http.Response, error) { + + // use reflect as attributes are private, also are pointers so need to cast it multiple times + if IDBody := (*kClient.CreateIdentityBody)(reflect.ValueOf(r).FieldByName("createIdentityBody").UnsafePointer()); !reflect.DeepEqual(*IDBody, *identityBody) { + t.Errorf("expected body to be %v, got %v", identityBody, IDBody) + } + + if test.expected.err != nil { + rr := httptest.NewRecorder() + rr.Header().Set("Content-Type", "application/json") + rr.WriteHeader(http.StatusInternalServerError) + + json.NewEncoder(rr).Encode( + map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusInternalServerError, + "debug": "--------", + "details": map[string]interface{}{}, + "id": "string", + "message": "error", + "reason": "error", + "request": "d7ef54b1-ec15-46e6-bccb-524b82c035e6", + "status": "Internal Server Error", + }, + }, + ) + + return nil, rr.Result(), fmt.Errorf("error") + } + + return kIdentity, new(http.Response), nil + }, + ) + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + newIdentity, err := svc.CreateIdentity(ctx, test.input.identity) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v not %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if newIdentity.Id != nil && *newIdentity.Id != id { + t.Errorf("expected ID to be %s, not %s", id, *newIdentity.Id) + } + + if newIdentity.Email != identity.Email { + t.Errorf("expected email to be %s, not %s", identity.Email, newIdentity.Email) + } + + if newIdentity.FirstName != nil && *newIdentity.FirstName != *identity.FirstName { + t.Errorf("expected name to be %s, not %s", *identity.FirstName, *newIdentity.FirstName) + } + + if newIdentity.LastName != nil && *newIdentity.LastName != *identity.LastName { + t.Errorf("expected surname to be %s, not %s", *identity.LastName, *newIdentity.LastName) + } + + }) + } +} + +func TestV1ServiceGetIdentity(t *testing.T) { + type expected struct { + err error + identity *resources.Identity + } + + id := uuid.NewString() + name := "Test" + surname := "User" + email := "test@gmail.com" + kIdentity := kClient.NewIdentity( + id, + "test", + "https://test.com/test.json", + map[string]string{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ) + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "error", + input: uuid.NewString(), + expected: expected{ + err: fmt.Errorf("error"), + }, + }, + { + name: "success", + input: id, + expected: expected{ + identity: &resources.Identity{ + Id: &id, + Email: email, + FirstName: &name, + LastName: &surname, + }, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + identityRequest := kClient.IdentityAPIGetIdentityRequest{ + ApiService: mockKratosIdentityAPI, + } + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockKratosIdentityAPI.EXPECT().GetIdentity(ctx, test.input).Times(1).Return(identityRequest) + mockKratosIdentityAPI.EXPECT().GetIdentityExecute(gomock.Any()).Times(1).DoAndReturn( + func(r kClient.IdentityAPIGetIdentityRequest) (*kClient.Identity, *http.Response, error) { + if test.expected.err != nil { + rr := httptest.NewRecorder() + rr.Header().Set("Content-Type", "application/json") + rr.WriteHeader(http.StatusNotFound) + + json.NewEncoder(rr).Encode( + map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusNotFound, + "debug": "--------", + "details": map[string]interface{}{}, + "id": "string", + "message": "error", + "reason": "error", + "request": "d7ef54b1-ec15-46e6-bccb-524b82c035e6", + "status": "Not Found", + }, + }, + ) + + return nil, rr.Result(), fmt.Errorf("error") + } + + return kIdentity, new(http.Response), nil + }, + ) + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + identity, err := svc.GetIdentity(ctx, test.input) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v not %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if identity.Id != nil && *identity.Id != id { + t.Errorf("expected ID to be %s, not %s", id, *identity.Id) + } + + if identity.Email != test.expected.identity.Email { + t.Errorf("expected email to be %s, not %s", test.expected.identity.Email, identity.Email) + } + + if identity.FirstName != nil && *identity.FirstName != *test.expected.identity.FirstName { + t.Errorf("expected name to be %s, not %s", *test.expected.identity.FirstName, *identity.FirstName) + } + + if identity.LastName != nil && *identity.LastName != *test.expected.identity.LastName { + t.Errorf("expected surname to be %s, not %s", *test.expected.identity.LastName, *identity.LastName) + } + }, + ) + } +} + +func TestV1ServiceUpdateIdentity(t *testing.T) { + type expected struct { + err error + identity *resources.Identity + } + + id := uuid.NewString() + name := "Test" + surname := "User" + email := "test@gmail.com" + kIdentity := kClient.NewIdentity( + id, + "test", + "https://test.com/test.json", + map[string]string{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ) + + tests := []struct { + name string + input *resources.Identity + expected expected + }{ + { + name: "error", + input: &resources.Identity{ + Id: &id, + Email: email, + FirstName: &name, + LastName: &surname, + }, + expected: expected{ + err: fmt.Errorf("error"), + }, + }, + { + name: "success", + input: &resources.Identity{ + Id: &id, + Email: email, + FirstName: &name, + LastName: &surname, + }, + expected: expected{ + identity: &resources.Identity{ + Id: &id, + Email: email, + FirstName: &name, + LastName: &surname, + }, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + identityRequest := kClient.IdentityAPIUpdateIdentityRequest{ + ApiService: mockKratosIdentityAPI, + } + + identityBody := kClient.NewUpdateIdentityBodyWithDefaults() + // identityBody.SetSchemaId(kIdentity.SchemaId) + identityBody.SetTraits(map[string]interface{}{ + "name": fmt.Sprintf("%s %s", name, surname), + "email": email, + }, + ) + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockKratosIdentityAPI.EXPECT().UpdateIdentity(gomock.Any(), *test.input.Id).Times(1).Return(identityRequest) + mockKratosIdentityAPI.EXPECT().UpdateIdentityExecute(gomock.Any()).Times(1).DoAndReturn( + func(r kClient.IdentityAPIUpdateIdentityRequest) (*kClient.Identity, *http.Response, error) { + + // use reflect as attributes are private, also are pointers so need to cast it multiple times + if IDBody := (*kClient.UpdateIdentityBody)(reflect.ValueOf(r).FieldByName("updateIdentityBody").UnsafePointer()); !reflect.DeepEqual(*IDBody, *identityBody) { + t.Errorf("expected body to be %v, got %v", identityBody, IDBody) + } + + if test.expected.err != nil { + rr := httptest.NewRecorder() + rr.Header().Set("Content-Type", "application/json") + rr.WriteHeader(http.StatusNotFound) + + json.NewEncoder(rr).Encode( + map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusConflict, + "debug": "--------", + "details": map[string]interface{}{}, + "id": "string", + "message": "error", + "reason": "error", + "request": "d7ef54b1-ec15-46e6-bccb-524b82c035e6", + "status": "Conflict", + }, + }, + ) + + return nil, rr.Result(), fmt.Errorf("error") + } + + return kIdentity, new(http.Response), nil + }, + ) + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + identity, err := svc.UpdateIdentity(ctx, test.input) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v not %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if identity.Id != nil && *identity.Id != id { + t.Errorf("expected ID to be %s, not %s", id, *identity.Id) + } + + if identity.Email != test.expected.identity.Email { + t.Errorf("expected email to be %s, not %s", test.expected.identity.Email, identity.Email) + } + + if identity.FirstName != nil && *identity.FirstName != *test.expected.identity.FirstName { + t.Errorf("expected name to be %s, not %s", *test.expected.identity.FirstName, *identity.FirstName) + } + + if identity.LastName != nil && *identity.LastName != *test.expected.identity.LastName { + t.Errorf("expected surname to be %s, not %s", *test.expected.identity.LastName, *identity.LastName) + } + }, + ) + } +} + +func TestV1ServiceDeleteIdentity(t *testing.T) { + type expected struct { + err error + ok bool + } + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "error", + input: uuid.NewString(), + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "success", + input: uuid.NewString(), + expected: expected{ + ok: true, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + identityRequest := kClient.IdentityAPIDeleteIdentityRequest{ + ApiService: mockKratosIdentityAPI, + } + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(ctx, gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockAuthz.EXPECT().SetDeleteIdentityEntitlements(gomock.Any(), test.input).MinTimes(0).MaxTimes(1) + mockKratosIdentityAPI.EXPECT().DeleteIdentity(ctx, test.input).Times(1).Return(identityRequest) + mockKratosIdentityAPI.EXPECT().DeleteIdentityExecute(gomock.Any()).Times(1).DoAndReturn( + func(r kClient.IdentityAPIDeleteIdentityRequest) (*http.Response, error) { + if test.expected.err != nil { + rr := httptest.NewRecorder() + rr.Header().Set("Content-Type", "application/json") + rr.WriteHeader(http.StatusNotFound) + + json.NewEncoder(rr).Encode( + map[string]interface{}{ + "error": map[string]interface{}{ + "code": http.StatusNotFound, + "debug": "--------", + "details": map[string]interface{}{}, + "id": "string", + "message": "error", + "reason": "error", + "request": "d7ef54b1-ec15-46e6-bccb-524b82c035e6", + "status": "Not Found", + }, + }, + ) + + return rr.Result(), fmt.Errorf("error") + } + + return new(http.Response), nil + }, + ) + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + ok, err := svc.DeleteIdentity(ctx, test.input) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v not %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if ok != test.expected.ok { + t.Errorf("expected result to be %v, not %v", test.expected.ok, ok) + } + }, + ) + } +} + +func TestV1ServiceGetIdentityGroups(t *testing.T) { + type expected struct { + groups []resources.Group + err error + } + + cLevel := "c-level" + itAdmin := "it-admin" + devops := "devops" + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "empty result", + input: uuid.NewString(), + expected: expected{ + groups: []resources.Group{}, + err: nil, + }, + }, + { + name: "error", + input: uuid.NewString(), + expected: expected{ + err: fmt.Errorf("error"), + }, + }, + { + name: "full result", + input: uuid.NewString(), + expected: expected{ + groups: []resources.Group{ + {Id: &cLevel, Name: cLevel}, + {Id: &itAdmin, Name: itAdmin}, + {Id: &devops, Name: devops}, + }, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().ListAssignedGroups(gomock.Any(), fmt.Sprintf("user:%s", test.input)).DoAndReturn( + func(ctx context.Context, ID string) ([]string, error) { + if test.expected.err != nil { + return nil, fmt.Errorf("error") + } + + groups := make([]string, 0) + + for _, g := range test.expected.groups { + groups = append(groups, g.Name) + } + + return groups, nil + }, + ) + + r, err := svc.GetIdentityGroups(context.Background(), test.input, nil) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + for i, group := range r.Data { + if group.Name != test.expected.groups[i].Name { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.groups[i].Name, group.Name) + } + } + + }) + } +} + +func TestV1ServiceGetIdentityRoles(t *testing.T) { + type expected struct { + roles []resources.Role + err error + } + + cLevel := "c-level" + itAdmin := "it-admin" + devops := "devops" + + tests := []struct { + name string + input string + expected expected + }{ + { + name: "empty result", + input: uuid.NewString(), + expected: expected{ + roles: []resources.Role{}, + err: nil, + }, + }, + { + name: "error", + input: uuid.NewString(), + expected: expected{ + err: fmt.Errorf("error"), + }, + }, + { + name: "full result", + input: uuid.NewString(), + expected: expected{ + roles: []resources.Role{ + {Id: &cLevel, Name: cLevel}, + {Id: &itAdmin, Name: itAdmin}, + {Id: &devops, Name: devops}, + }, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().ListAssignedRoles(gomock.Any(), fmt.Sprintf("user:%s", test.input)).DoAndReturn( + func(ctx context.Context, ID string) ([]string, error) { + if test.expected.err != nil { + return nil, fmt.Errorf("error") + } + + roles := make([]string, 0) + + for _, r := range test.expected.roles { + roles = append(roles, r.Name) + } + + return roles, nil + }, + ) + + r, err := svc.GetIdentityRoles(context.Background(), test.input, nil) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + for i, role := range r.Data { + if role.Name != test.expected.roles[i].Name { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.roles[i].Name, role.Name) + } + } + }) + } +} + +func TestV1ServicePatchIdentityRoles(t *testing.T) { + type input struct { + patches []resources.IdentityRolesPatchItem + id string + } + type expected struct { + ok bool + err error + } + + additions := []resources.IdentityRolesPatchItem{ + {Op: "add", Role: "test1"}, + {Op: "add", Role: "test2"}, + } + removals := []resources.IdentityRolesPatchItem{ + {Op: "remove", Role: "test1"}, + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "empty payload", + input: input{ + id: uuid.NewString(), + patches: []resources.IdentityRolesPatchItem{}, + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + { + name: "error assign", + input: input{ + id: uuid.NewString(), + patches: additions, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "error unassign", + input: input{ + id: uuid.NewString(), + patches: removals, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "success", + input: input{ + id: uuid.NewString(), + patches: append(removals, additions...), + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + // AssignRoles(context.Context, string, ...string) error + // UnassignRoles(context.Context, string, ...string) error + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().AssignRoles(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, roles ...string) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + rs := make([]string, 0) + + for _, r := range test.input.patches { + if r.Op == "add" { + rs = append(rs, fmt.Sprintf("role:%s", r.Role)) + } + } + + if !reflect.DeepEqual(rs, roles) { + t.Errorf("expected roles to be %v got %v", rs, roles) + } + + return nil + }, + ) + + mockOpenFGAStore.EXPECT().UnassignRoles(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, roles ...string) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + rs := make([]string, 0) + + for _, r := range test.input.patches { + if r.Op == "remove" { + rs = append(rs, fmt.Sprintf("role:%s", r.Role)) + } + } + + if !reflect.DeepEqual(rs, roles) { + t.Errorf("expected roles to be %v got %v", rs, roles) + } + + return nil + }, + ) + + ok, err := svc.PatchIdentityRoles(context.Background(), test.input.id, test.input.patches) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if ok != test.expected.ok { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.ok, ok) + } + }) + } +} + +func TestV1ServicePatchIdentityGroups(t *testing.T) { + type input struct { + patches []resources.IdentityGroupsPatchItem + id string + } + type expected struct { + ok bool + err error + } + + additions := []resources.IdentityGroupsPatchItem{ + {Op: "add", Group: "test1"}, + {Op: "add", Group: "test2"}, + } + removals := []resources.IdentityGroupsPatchItem{ + {Op: "remove", Group: "test1"}, + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "empty payload", + input: input{ + id: uuid.NewString(), + patches: []resources.IdentityGroupsPatchItem{}, + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + { + name: "error assign", + input: input{ + id: uuid.NewString(), + patches: additions, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "error unassign", + input: input{ + id: uuid.NewString(), + patches: removals, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "success", + input: input{ + id: uuid.NewString(), + patches: append(removals, additions...), + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + // AssignGroups(context.Context, string, ...string) error + // UnassignGroups(context.Context, string, ...string) error + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().AssignGroups(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, groups ...string) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + gs := make([]string, 0) + + for _, g := range test.input.patches { + if g.Op == "add" { + gs = append(gs, fmt.Sprintf("group:%s", g.Group)) + } + } + + if !reflect.DeepEqual(gs, groups) { + t.Errorf("expected groups to be %v got %v", gs, groups) + } + + return nil + }, + ) + + mockOpenFGAStore.EXPECT().UnassignGroups(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, groups ...string) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + gs := make([]string, 0) + + for _, g := range test.input.patches { + if g.Op == "remove" { + gs = append(gs, fmt.Sprintf("group:%s", g.Group)) + } + } + + if !reflect.DeepEqual(gs, groups) { + t.Errorf("expected groups to be %v got %v", gs, groups) + } + + return nil + }, + ) + + ok, err := svc.PatchIdentityGroups(context.Background(), test.input.id, test.input.patches) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if ok != test.expected.ok { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.ok, ok) + } + }) + } +} + +func TestV1ServiceGetIdentityEntitlements(t *testing.T) { + type input struct { + params *resources.GetIdentitiesItemEntitlementsParams + id string + } + type expected struct { + permissions []resources.EntityEntitlement + err error + } + + permissions := []resources.EntityEntitlement{ + { + Entitlement: "can_view", + EntityId: "okta", + EntityType: "client", + }, + { + Entitlement: "can_delete", + EntityId: "github", + EntityType: "client", + }, + { + Entitlement: "can_create", + EntityId: "github", + EntityType: "client", + }, + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "empty payload", + input: input{ + id: uuid.NewString(), + }, + expected: expected{ + permissions: []resources.EntityEntitlement{}, + err: nil, + }, + }, + { + name: "error", + input: input{ + id: uuid.NewString(), + }, + expected: expected{ + err: fmt.Errorf("error"), + }, + }, + { + name: "success", + input: input{ + id: uuid.NewString(), + }, + expected: expected{ + permissions: permissions, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().ListPermissions(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).Times(1).DoAndReturn( + func(ctx context.Context, ID string, tokens map[string]string) ([]ofga.Permission, map[string]string, error) { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return nil, nil, fmt.Errorf("error") + } + + ps := make([]ofga.Permission, 0) + + for _, p := range test.expected.permissions { + ps = append( + ps, + ofga.Permission{ + Relation: p.Entitlement, + Object: fmt.Sprintf("%s:%s", p.EntityType, p.EntityId), + }, + ) + } + return ps, map[string]string{}, nil + }, + ) + + r, err := svc.GetIdentityEntitlements(context.Background(), test.input.id, test.input.params) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if !reflect.DeepEqual(r.Data, test.expected.permissions) { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.permissions, r.Data) + } + }) + } +} + +func TestV1ServicePatchIdentityEntitlements(t *testing.T) { + type input struct { + patches []resources.IdentityEntitlementsPatchItem + id string + } + type expected struct { + ok bool + err error + } + + additions := []resources.IdentityEntitlementsPatchItem{ + { + Op: "add", + Entitlement: resources.EntityEntitlement{ + Entitlement: "can_view", + EntityId: "okta", + EntityType: "client", + }, + }, + { + Op: "add", + Entitlement: resources.EntityEntitlement{ + Entitlement: "can_delete", + EntityId: "github", + EntityType: "client", + }, + }, + } + + removals := []resources.IdentityEntitlementsPatchItem{ + { + Op: "remove", + Entitlement: resources.EntityEntitlement{ + Entitlement: "can_create", + EntityId: "github", + EntityType: "client", + }, + }, + } + + tests := []struct { + name string + input input + expected expected + }{ + { + name: "empty payload", + input: input{ + id: uuid.NewString(), + patches: []resources.IdentityEntitlementsPatchItem{}, + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + { + name: "error assign", + input: input{ + id: uuid.NewString(), + patches: additions, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "error unassign", + input: input{ + id: uuid.NewString(), + patches: removals, + }, + expected: expected{ + err: fmt.Errorf("error"), + ok: false, + }, + }, + { + name: "success", + input: input{ + id: uuid.NewString(), + patches: append(removals, additions...), + }, + expected: expected{ + ok: true, + err: nil, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := NewMockMonitorInterface(ctrl) + mockCoreV1 := NewMockCoreV1Interface(ctrl) + mockAuthz := NewMockAuthorizerInterface(ctrl) + mockKratosIdentityAPI := NewMockIdentityAPI(ctrl) + mockOpenFGAStore := NewMockOpenFGAStoreInterface(ctrl) + + ctx := context.Background() + + cfg := new(Config) + cfg.K8s = mockCoreV1 + cfg.Name = "schemas" + cfg.Namespace = "default" + cfg.OpenFGAStore = mockOpenFGAStore + + cm := new(corev1.ConfigMap) + cm.Data = make(map[string]string) + cm.Data[DEFAULT_SCHEMA] = "test" + + svc := NewV1Service( + cfg, + NewService(mockKratosIdentityAPI, mockAuthz, mockTracer, mockMonitor, mockLogger), + ) + + // AssignGroups(context.Context, string, ...string) error + // UnassignGroups(context.Context, string, ...string) error + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(ctx, trace.SpanFromContext(ctx)) + mockOpenFGAStore.EXPECT().AssignPermissions(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, permissions ...ofga.Permission) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + ps := make([]ofga.Permission, 0) + + for _, p := range test.input.patches { + if p.Op == "add" { + ps = append( + ps, + ofga.Permission{ + Relation: p.Entitlement.Entitlement, + Object: fmt.Sprintf("%s:%s", p.Entitlement.EntityType, p.Entitlement.EntityId), + }, + ) + } + } + + if !reflect.DeepEqual(ps, permissions) { + t.Errorf("expected permissions to be %v got %v", ps, permissions) + } + + return nil + }, + ) + + mockOpenFGAStore.EXPECT().UnassignPermissions(gomock.Any(), fmt.Sprintf("user:%s", test.input.id), gomock.Any()).MinTimes(0).MaxTimes(1).DoAndReturn( + func(ctx context.Context, ID string, permissions ...ofga.Permission) error { + if ID != fmt.Sprintf("user:%s", test.input.id) { + t.Errorf("expected ID to be user:%s got %s", test.input.id, ID) + } + + if test.expected.err != nil { + return fmt.Errorf("error") + } + + ps := make([]ofga.Permission, 0) + + for _, p := range test.input.patches { + if p.Op == "remove" { + ps = append( + ps, + ofga.Permission{ + Relation: p.Entitlement.Entitlement, + Object: fmt.Sprintf("%s:%s", p.Entitlement.EntityType, p.Entitlement.EntityId), + }, + ) + } + } + + if !reflect.DeepEqual(ps, permissions) { + t.Errorf("expected permissions to be %v got %v", ps, permissions) + } + + return nil + }, + ) + + ok, err := svc.PatchIdentityEntitlements(context.Background(), test.input.id, test.input.patches) + + if test.expected.err != nil && err == nil { + t.Errorf("expected error to be %v got %v", test.expected.err, err) + } + + if test.expected.err != nil { + return + } + + if ok != test.expected.ok { + t.Errorf("invalid result, expected: %v, got: %v", test.expected.ok, ok) + } + }) + } +}