Skip to content

Commit

Permalink
Merge pull request #778 from Permify/tests
Browse files Browse the repository at this point in the history
Tests
  • Loading branch information
tolgaOzen authored Oct 25, 2023
2 parents 4f3c10d + b8a2eac commit 947191a
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 5 deletions.
2 changes: 1 addition & 1 deletion internal/authn/oidc/fakes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (s *fakeOidcProvider) SignIDToken(unsignedToken *jwt.Token) (string, error)
signedToken, err = unsignedToken.SignedString(s.rsaPrivateKeyForPS)

default:
return "", fmt.Errorf("Incorrect signing method type, supported algorithms: HS256, RS256, ES256, PS256")
return "", fmt.Errorf("incorrect signing method type, supported algorithms: HS256, RS256, ES256, PS256")
}

if err != nil {
Expand Down
20 changes: 16 additions & 4 deletions internal/authn/oidc/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,51 @@ import (
"google.golang.org/grpc"
)

// UnaryServerInterceptor -
// UnaryServerInterceptor returns a gRPC unary server interceptor that
// performs authentication using the provided Authenticator.
func UnaryServerInterceptor(t Authenticator) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// Authenticate the request.
err := t.Authenticate(ctx)
if err != nil {
// If authentication fails, return the error.
return nil, err
}
// If authentication succeeds, proceed with the request.
return handler(ctx, req)
}
}

// StreamServerInterceptor -
// StreamServerInterceptor returns a gRPC stream server interceptor that
// wraps the incoming stream with an authenticator.
func StreamServerInterceptor(t Authenticator) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
// Wrap the stream with the authenticator.
wrapper := &authnWrapper{ServerStream: stream, authenticator: t}
return handler(srv, wrapper)
}
}

// authnWrapper -
// authnWrapper wraps a grpc.ServerStream and intercepts its RecvMsg
// method to perform authentication on each message received.
type authnWrapper struct {
grpc.ServerStream
authenticator Authenticator
}

// RecvMsg -
// RecvMsg intercepts the RecvMsg call of the wrapped grpc.ServerStream
// to perform authentication before processing the message.
func (s *authnWrapper) RecvMsg(req interface{}) error {
// Receive the message from the original stream.
if err := s.ServerStream.RecvMsg(req); err != nil {
return err
}
// Authenticate the received message.
err := s.authenticator.Authenticate(s.ServerStream.Context())
if err != nil {
// If authentication fails, return the error.
return err
}
// If authentication succeeds, proceed with the message.
return nil
}
164 changes: 164 additions & 0 deletions internal/authn/oidc/interceptors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package oidc

import (
"context"
"errors"
"testing"

"google.golang.org/grpc/metadata"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"google.golang.org/grpc"
)

func TestAuthInterceptors(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "authentication interceptors suite")
}

var _ = Describe("Auth Interceptors", func() {
fakeError := errors.New("fake authentication error")

Describe("UnaryServerInterceptor", func() {
var authenticator Authenticator
var interceptor grpc.UnaryServerInterceptor
var handlerCalled bool

BeforeEach(func() {
handlerCalled = false
})

Context("when authentication is successful", func() {
BeforeEach(func() {
authenticator = &mockAuthenticator{err: nil}
interceptor = UnaryServerInterceptor(authenticator)
})

It("should call the handler and not return an error", func() {
_, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return "success", nil
})
Expect(err).NotTo(HaveOccurred())
Expect(handlerCalled).To(BeTrue())
})
})

Context("when authentication fails", func() {
BeforeEach(func() {
authenticator = &mockAuthenticator{err: fakeError}
interceptor = UnaryServerInterceptor(authenticator)
})

It("should not call the handler and return an error", func() {
_, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return nil, nil
})
Expect(err).To(MatchError(fakeError))
Expect(handlerCalled).To(BeFalse())
})
})
})

Describe("StreamServerInterceptor", func() {
var authenticator Authenticator
var interceptor grpc.StreamServerInterceptor
var handlerCalled bool
var mockStream *mockServerStream

BeforeEach(func() {
handlerCalled = false
mockStream = &mockServerStream{}
})

Context("when authentication is successful", func() {
BeforeEach(func() {
authenticator = &mockAuthenticator{err: nil}
interceptor = StreamServerInterceptor(authenticator)
})

It("should call the handler and not return an error", func() {
err := interceptor(nil, mockStream, nil, func(srv interface{}, stream grpc.ServerStream) error {
handlerCalled = true
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(handlerCalled).To(BeTrue())
})
})
})

Describe("authnWrapper", func() {
var wrapper *authnWrapper
var mockStream *mockServerStream
var authenticator Authenticator

BeforeEach(func() {
mockStream = &mockServerStream{}
})

Context("when authentication is successful", func() {
BeforeEach(func() {
authenticator = &mockAuthenticator{err: nil}
wrapper = &authnWrapper{ServerStream: mockStream, authenticator: authenticator}
})

It("should call the original RecvMsg and not return an error", func() {
err := wrapper.RecvMsg(nil)
Expect(err).NotTo(HaveOccurred())
Expect(mockStream.recvMsgCalled).To(BeTrue())
})
})

Context("when authentication fails", func() {
BeforeEach(func() {
authenticator = &mockAuthenticator{err: fakeError}
wrapper = &authnWrapper{ServerStream: mockStream, authenticator: authenticator}
})

It("should return an error without processing the message", func() {
err := wrapper.RecvMsg(nil)
Expect(err).To(MatchError(fakeError))
Expect(mockStream.recvMsgCalled).To(BeTrue())
})
})
})
})

// mockServerStream is a fake implementation of the grpc.ServerStream for testing.
type mockServerStream struct {
recvMsgCalled bool
}

func (m *mockServerStream) SetHeader(md metadata.MD) error {
return nil
}

func (m *mockServerStream) SendHeader(md metadata.MD) error {
return nil
}

func (m *mockServerStream) SetTrailer(md metadata.MD) {}

func (m *mockServerStream) Context() context.Context {
return context.Background()
}

func (m *mockServerStream) SendMsg(a any) error {
return nil
}

func (m *mockServerStream) RecvMsg(x interface{}) error {
m.recvMsgCalled = true
return nil
}

type mockAuthenticator struct {
err error
}

func (m *mockAuthenticator) Authenticate(ctx context.Context) error {
return m.err
}
75 changes: 75 additions & 0 deletions internal/authn/preshared/authn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package preshared

import (
"context"
"testing"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/Permify/permify/internal/config"
base "github.com/Permify/permify/pkg/pb/base/v1"
)

func TestPresharedKeyAuth(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "authentication preshared key suite")
}

var _ = Describe("KeyAuthn", func() {
var (
ctx context.Context
authenticator *KeyAuthn
err error
keysConfig config.Preshared
)

BeforeEach(func() {
keysConfig = config.Preshared{Keys: []string{"key1", "key2"}}
authenticator, err = NewKeyAuthn(context.Background(), keysConfig)
Expect(err).ToNot(HaveOccurred())
})

Describe("Authenticate", func() {
Context("with valid Bearer token", func() {
BeforeEach(func() {
md := metadata.New(map[string]string{"authorization": "Bearer key1"})
ctx = metadata.NewIncomingContext(context.Background(), md)
})

It("should authenticate successfully", func() {
err := authenticator.Authenticate(ctx)
Expect(err).ToNot(HaveOccurred())
})
})

Context("with invalid Bearer token", func() {
BeforeEach(func() {
md := metadata.New(map[string]string{"authorization": "Bearer invalidkey"})
ctx = metadata.NewIncomingContext(context.Background(), md)
})

It("should return an error", func() {
err := authenticator.Authenticate(ctx)
Expect(err).To(HaveOccurred())
Expect(status.Code(err)).To(Equal(codes.Unauthenticated))
})
})

Context("with missing Bearer token", func() {
BeforeEach(func() {
ctx = context.Background()
})

It("should return an error", func() {
err := authenticator.Authenticate(ctx)
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(Equal(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String()))
})
})
})
})
Loading

0 comments on commit 947191a

Please sign in to comment.