Skip to content

Commit

Permalink
feat: Add AuthMiddleware middleware to server (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaikelVeen authored Jan 25, 2025
1 parent 42cef10 commit 1099b98
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 11 deletions.
22 changes: 18 additions & 4 deletions cmd/server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@ package server
import (
"fmt"
"log/slog"
"net/http"
"os"
"os/user"

"github.com/glass-cms/glasscms/internal/auth"
authRepository "github.com/glass-cms/glasscms/internal/auth/repository"
"github.com/glass-cms/glasscms/internal/database"
"github.com/glass-cms/glasscms/internal/item"
"github.com/glass-cms/glasscms/internal/item/repository"
itemRepository "github.com/glass-cms/glasscms/internal/item/repository"
"github.com/glass-cms/glasscms/internal/server"
internalMiddleware "github.com/glass-cms/glasscms/internal/server/middleware"
ctx "github.com/glass-cms/glasscms/pkg/context"
"github.com/glass-cms/glasscms/pkg/log"
"github.com/glass-cms/glasscms/pkg/mediatype"
"github.com/glass-cms/glasscms/pkg/middleware"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -91,10 +97,18 @@ func (c *StartCommand) Execute(cmd *cobra.Command, _ []string) error {
return err
}

repo := repository.NewRepository(db, errHandler)
service := item.NewService(repo)
itemRepo := itemRepository.NewRepository(db, errHandler)
itemService := item.NewService(itemRepo)

server, err := server.New(logger, service)
authRepo := authRepository.NewRepository(db, errHandler)
authService := auth.NewAuth(db, authRepo, logger)

server, err := server.New(logger, itemService, []func(http.Handler) http.Handler{
middleware.RequestID,
middleware.ContentType(mediatype.ApplicationJSON),
middleware.Accept(mediatype.ApplicationJSON),
internalMiddleware.AuthMiddleware(authService),
})
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions internal/server/item_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func TestAPIHandler_ItemsCreate(t *testing.T) {
handler, err := server.New(
log.NoopLogger(),
item.NewService(repo),
[]func(http.Handler) http.Handler{},
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -109,6 +110,7 @@ func TestAPIHandler_ItemsGet(t *testing.T) {
server, err := server.New(
log.NoopLogger(),
item.NewService(repo),
[]func(http.Handler) http.Handler{},
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -202,6 +204,7 @@ func TestAPIHandler_ItemsList(t *testing.T) {
handler, err := server.New(
log.NoopLogger(),
svc,
[]func(http.Handler) http.Handler{},
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -285,6 +288,7 @@ func TestAPIHandler_ItemsUpsert(t *testing.T) {
handler, err := server.New(
log.NoopLogger(),
svc,
[]func(http.Handler) http.Handler{},
)
if err != nil {
t.Fatal(err)
Expand Down
37 changes: 37 additions & 0 deletions internal/server/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package middleware

import (
"context"
"net/http"
)

//go:generate moq -out mock_auth.go . Authentication

type Authentication interface {
ValidateToken(ctx context.Context, token string) (bool, error)
}

// AuthMiddleware creates an http middleware that validates auth tokens in requests.
//
// It takes an Authentication interface and returns a middleware function that checks
// for valid Authorization header tokens, responding with 401 Unauthorized if
// validation fails.
func AuthMiddleware(auth Authentication) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token == "" {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

valid, err := auth.ValidateToken(r.Context(), token)
if err != nil || !valid {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
})
}
}
59 changes: 59 additions & 0 deletions internal/server/middleware/auth_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package middleware_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/glass-cms/glasscms/internal/server/middleware"
"github.com/stretchr/testify/assert"
)

func TestAuthMiddleware(t *testing.T) {
mockAuth := &middleware.AuthenticationMock{}

handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

middleware := middleware.AuthMiddleware(mockAuth)
wrappedHandler := middleware(handler)

t.Run("Valid Token", func(t *testing.T) {
mockAuth.ValidateTokenFunc = func(_ context.Context, _ string) (bool, error) {
return true, nil
}

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "valid-token")
w := httptest.NewRecorder()

wrappedHandler.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
})

t.Run("Invalid Token", func(t *testing.T) {
mockAuth.ValidateTokenFunc = func(_ context.Context, _ string) (bool, error) {
return false, nil
}

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "invalid-token")
w := httptest.NewRecorder()

wrappedHandler.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)
})

t.Run("Missing Token", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()

wrappedHandler.ServeHTTP(w, req)

assert.Equal(t, http.StatusUnauthorized, w.Code)
})
}
2 changes: 2 additions & 0 deletions internal/server/middleware/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package middleware contains application specific middleware.
package middleware
81 changes: 81 additions & 0 deletions internal/server/middleware/mock_auth.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 1 addition & 7 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"github.com/glass-cms/glasscms/internal/item"
"github.com/glass-cms/glasscms/pkg/api"
"github.com/glass-cms/glasscms/pkg/fieldmask"
"github.com/glass-cms/glasscms/pkg/mediatype"
"github.com/glass-cms/glasscms/pkg/middleware"
"github.com/glass-cms/glasscms/pkg/resource"
)

Expand All @@ -38,6 +36,7 @@ type Server struct {
func New(
logger *slog.Logger,
itemService *item.Service,
middlewares []func(http.Handler) http.Handler,
opts ...Option,
) (*Server, error) {
serveMux := http.NewServeMux()
Expand All @@ -48,11 +47,6 @@ func New(
errorHandler: NewErrorHandler(),
}

middlewares := []func(http.Handler) http.Handler{
middleware.RequestID,
middleware.ContentType(mediatype.ApplicationJSON),
middleware.Accept(mediatype.ApplicationJSON),
}
convertedMiddlewares := make([]api.MiddlewareFunc, len(middlewares))
for i, mw := range middlewares {
convertedMiddlewares[i] = api.MiddlewareFunc(mw)
Expand Down

0 comments on commit 1099b98

Please sign in to comment.