diff --git a/cmd/server/start.go b/cmd/server/start.go index 998386e..94ffffe 100644 --- a/cmd/server/start.go +++ b/cmd/server/start.go @@ -3,15 +3,21 @@ package server import ( "fmt" "log/slog" + "net/http" "os" "os/user" + "github.com/glass-cms/glasscms/internal/auth" + authRepository "github.com/glass-cms/glasscms/internal/auth/repository" "github.com/glass-cms/glasscms/internal/database" "github.com/glass-cms/glasscms/internal/item" - "github.com/glass-cms/glasscms/internal/item/repository" + itemRepository "github.com/glass-cms/glasscms/internal/item/repository" "github.com/glass-cms/glasscms/internal/server" + internalMiddleware "github.com/glass-cms/glasscms/internal/server/middleware" ctx "github.com/glass-cms/glasscms/pkg/context" "github.com/glass-cms/glasscms/pkg/log" + "github.com/glass-cms/glasscms/pkg/mediatype" + "github.com/glass-cms/glasscms/pkg/middleware" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -91,10 +97,18 @@ func (c *StartCommand) Execute(cmd *cobra.Command, _ []string) error { return err } - repo := repository.NewRepository(db, errHandler) - service := item.NewService(repo) + itemRepo := itemRepository.NewRepository(db, errHandler) + itemService := item.NewService(itemRepo) - server, err := server.New(logger, service) + authRepo := authRepository.NewRepository(db, errHandler) + authService := auth.NewAuth(db, authRepo, logger) + + server, err := server.New(logger, itemService, []func(http.Handler) http.Handler{ + middleware.RequestID, + middleware.ContentType(mediatype.ApplicationJSON), + middleware.Accept(mediatype.ApplicationJSON), + internalMiddleware.AuthMiddleware(authService), + }) if err != nil { return err } diff --git a/internal/server/item_test.go b/internal/server/item_test.go index 9470715..4d82c31 100644 --- a/internal/server/item_test.go +++ b/internal/server/item_test.go @@ -59,6 +59,7 @@ func TestAPIHandler_ItemsCreate(t *testing.T) { handler, err := server.New( log.NoopLogger(), item.NewService(repo), + []func(http.Handler) http.Handler{}, ) if err != nil { t.Fatal(err) @@ -109,6 +110,7 @@ func TestAPIHandler_ItemsGet(t *testing.T) { server, err := server.New( log.NoopLogger(), item.NewService(repo), + []func(http.Handler) http.Handler{}, ) if err != nil { t.Fatal(err) @@ -202,6 +204,7 @@ func TestAPIHandler_ItemsList(t *testing.T) { handler, err := server.New( log.NoopLogger(), svc, + []func(http.Handler) http.Handler{}, ) if err != nil { t.Fatal(err) @@ -285,6 +288,7 @@ func TestAPIHandler_ItemsUpsert(t *testing.T) { handler, err := server.New( log.NoopLogger(), svc, + []func(http.Handler) http.Handler{}, ) if err != nil { t.Fatal(err) diff --git a/internal/server/middleware/auth.go b/internal/server/middleware/auth.go new file mode 100644 index 0000000..ed93de7 --- /dev/null +++ b/internal/server/middleware/auth.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "context" + "net/http" +) + +//go:generate moq -out mock_auth.go . Authentication + +type Authentication interface { + ValidateToken(ctx context.Context, token string) (bool, error) +} + +// AuthMiddleware creates an http middleware that validates auth tokens in requests. +// +// It takes an Authentication interface and returns a middleware function that checks +// for valid Authorization header tokens, responding with 401 Unauthorized if +// validation fails. +func AuthMiddleware(auth Authentication) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + valid, err := auth.ValidateToken(r.Context(), token) + if err != nil || !valid { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/server/middleware/auth_middleware_test.go b/internal/server/middleware/auth_middleware_test.go new file mode 100644 index 0000000..447d04a --- /dev/null +++ b/internal/server/middleware/auth_middleware_test.go @@ -0,0 +1,59 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/glass-cms/glasscms/internal/server/middleware" + "github.com/stretchr/testify/assert" +) + +func TestAuthMiddleware(t *testing.T) { + mockAuth := &middleware.AuthenticationMock{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := middleware.AuthMiddleware(mockAuth) + wrappedHandler := middleware(handler) + + t.Run("Valid Token", func(t *testing.T) { + mockAuth.ValidateTokenFunc = func(_ context.Context, _ string) (bool, error) { + return true, nil + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "valid-token") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("Invalid Token", func(t *testing.T) { + mockAuth.ValidateTokenFunc = func(_ context.Context, _ string) (bool, error) { + return false, nil + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "invalid-token") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("Missing Token", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) +} diff --git a/internal/server/middleware/doc.go b/internal/server/middleware/doc.go new file mode 100644 index 0000000..46edd2a --- /dev/null +++ b/internal/server/middleware/doc.go @@ -0,0 +1,2 @@ +// Package middleware contains application specific middleware. +package middleware diff --git a/internal/server/middleware/mock_auth.go b/internal/server/middleware/mock_auth.go new file mode 100644 index 0000000..9ddcbc6 --- /dev/null +++ b/internal/server/middleware/mock_auth.go @@ -0,0 +1,81 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package middleware + +import ( + "context" + "sync" +) + +// Ensure, that AuthenticationMock does implement Authentication. +// If this is not the case, regenerate this file with moq. +var _ Authentication = &AuthenticationMock{} + +// AuthenticationMock is a mock implementation of Authentication. +// +// func TestSomethingThatUsesAuthentication(t *testing.T) { +// +// // make and configure a mocked Authentication +// mockedAuthentication := &AuthenticationMock{ +// ValidateTokenFunc: func(ctx context.Context, token string) (bool, error) { +// panic("mock out the ValidateToken method") +// }, +// } +// +// // use mockedAuthentication in code that requires Authentication +// // and then make assertions. +// +// } +type AuthenticationMock struct { + // ValidateTokenFunc mocks the ValidateToken method. + ValidateTokenFunc func(ctx context.Context, token string) (bool, error) + + // calls tracks calls to the methods. + calls struct { + // ValidateToken holds details about calls to the ValidateToken method. + ValidateToken []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + } + } + lockValidateToken sync.RWMutex +} + +// ValidateToken calls ValidateTokenFunc. +func (mock *AuthenticationMock) ValidateToken(ctx context.Context, token string) (bool, error) { + if mock.ValidateTokenFunc == nil { + panic("AuthenticationMock.ValidateTokenFunc: method is nil but Authentication.ValidateToken was just called") + } + callInfo := struct { + Ctx context.Context + Token string + }{ + Ctx: ctx, + Token: token, + } + mock.lockValidateToken.Lock() + mock.calls.ValidateToken = append(mock.calls.ValidateToken, callInfo) + mock.lockValidateToken.Unlock() + return mock.ValidateTokenFunc(ctx, token) +} + +// ValidateTokenCalls gets all the calls that were made to ValidateToken. +// Check the length with: +// +// len(mockedAuthentication.ValidateTokenCalls()) +func (mock *AuthenticationMock) ValidateTokenCalls() []struct { + Ctx context.Context + Token string +} { + var calls []struct { + Ctx context.Context + Token string + } + mock.lockValidateToken.RLock() + calls = mock.calls.ValidateToken + mock.lockValidateToken.RUnlock() + return calls +} diff --git a/internal/server/server.go b/internal/server/server.go index 66ca46d..bb5cf35 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,8 +11,6 @@ import ( "github.com/glass-cms/glasscms/internal/item" "github.com/glass-cms/glasscms/pkg/api" "github.com/glass-cms/glasscms/pkg/fieldmask" - "github.com/glass-cms/glasscms/pkg/mediatype" - "github.com/glass-cms/glasscms/pkg/middleware" "github.com/glass-cms/glasscms/pkg/resource" ) @@ -38,6 +36,7 @@ type Server struct { func New( logger *slog.Logger, itemService *item.Service, + middlewares []func(http.Handler) http.Handler, opts ...Option, ) (*Server, error) { serveMux := http.NewServeMux() @@ -48,11 +47,6 @@ func New( errorHandler: NewErrorHandler(), } - middlewares := []func(http.Handler) http.Handler{ - middleware.RequestID, - middleware.ContentType(mediatype.ApplicationJSON), - middleware.Accept(mediatype.ApplicationJSON), - } convertedMiddlewares := make([]api.MiddlewareFunc, len(middlewares)) for i, mw := range middlewares { convertedMiddlewares[i] = api.MiddlewareFunc(mw)