diff --git a/internal/server.go b/internal/server.go index b95bf9d..caf58e3 100644 --- a/internal/server.go +++ b/internal/server.go @@ -57,7 +57,7 @@ func NewBookclubServer(client Client, repository BookRepository, userRepository func jwtMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - tokenString := extractToken(r) //TODO: extract token + tokenString := extractToken(r) if tokenString == "" { respondWithError(w, http.StatusUnauthorized, "Missing JWT") return @@ -79,7 +79,7 @@ func extractToken(r *http.Request) string { } func validateToken(tokenString string) (jwt.RegisteredClaims, error) { - jwtSecret := os.Getenv("JWT_SECRET") //TODO: load from .env + jwtSecret := os.Getenv("JWT_SECRET") claims := jwt.RegisteredClaims{} token, err := jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(jwtSecret), nil diff --git a/internal/server_test.go b/internal/server_test.go index 9bc708c..5201410 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -1,28 +1,61 @@ package internal import ( + "encoding/json" "net/http" - "reflect" + "net/http/httptest" "testing" ) -func Test_jwtMiddleware(t *testing.T) { +func Test_jwtMiddleware_missing_jwt(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + mw := jwtMiddleware(handler) - type args struct { - next http.HandlerFunc + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + mw(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + + var got responseError + err := json.NewDecoder(resp.Body).Decode(&got) + if err != nil { + t.Fatalf("Unable to parse response from server %q, '%v'", resp.Body, err) + } + + if got.Error != "Missing JWT" { + t.Errorf("expected error message 'Missing JWT', got %q", got.Error) + } +} + +func Test_jwtMiddleware_invalid_jwt(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + mw := jwtMiddleware(handler) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Add("Authorization", "Bearer invalid") + w := httptest.NewRecorder() + mw(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode) } - tests := []struct { - name string - args args - want http.HandlerFunc - }{ - // TODO:: Add test cases. + + var got responseError + err := json.NewDecoder(resp.Body).Decode(&got) + if err != nil { + t.Fatalf("Unable to parse response from server %q, '%v'", resp.Body, err) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := jwtMiddleware(tt.args.next); !reflect.DeepEqual(got, tt.want) { - t.Errorf("jwtMiddleware() = %v, want %v", got, tt.want) - } - }) + + if got.Error != "Invalid JWT" { + t.Errorf("expected error message 'Missing JWT', got %q", got.Error) } } + +type responseError struct { + Error string `json:"error"` +}