Skip to content

Commit

Permalink
IAM: Set Cache-Control headers (#3147)
Browse files Browse the repository at this point in the history
  • Loading branch information
reinkrul authored May 31, 2024
1 parent 340e015 commit 050e189
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 10 deletions.
19 changes: 19 additions & 0 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/http/cache"
"github.com/nuts-foundation/nuts-node/http/user"
"html/template"
"net/http"
Expand Down Expand Up @@ -71,6 +72,22 @@ const accessTokenValidity = 15 * time.Minute

const oid4vciSessionValidity = 15 * time.Minute

// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
var cacheControlMaxAgeURLs = []string{
"/.well-known/did.json",
"/iam/:id/did.json",
"/oauth2/:did/presentation_definition",
"/.well-known/oauth-authorization-server/iam/:id",
"/.well-known/oauth-authorization-server",
"/oauth2/:did/oauth-client",
"/statuslist/:did/:page",
}

// cacheControlNoCacheURLs holds API endpoints that should have a no-cache cache control header set.
var cacheControlNoCacheURLs = []string{
"/oauth2/:did/token",
}

//go:embed assets
var assetsFS embed.FS

Expand Down Expand Up @@ -130,6 +147,8 @@ func (r Wrapper) Routes(router core.EchoRouter) {
return next(c)
}
}, audit.Middleware(apiModuleName))
router.Use(cache.MaxAge(5*time.Minute, cacheControlMaxAgeURLs...).Handle)
router.Use(cache.NoCache(cacheControlNoCacheURLs...).Handle)
router.Use(user.SessionMiddleware{
Skipper: func(c echo.Context) bool {
// The following URLs require a user session:
Expand Down
49 changes: 39 additions & 10 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,16 +698,45 @@ func TestWrapper_IntrospectAccessToken(t *testing.T) {
}

func TestWrapper_Routes(t *testing.T) {
ctrl := gomock.NewController(t)
router := core.NewMockEchoRouter(ctrl)

router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
router.EXPECT().Use(gomock.AssignableToTypeOf(user.SessionMiddleware{}.Handle))

(&Wrapper{
storageEngine: storage.NewTestStorageEngine(t),
}).Routes(router)
t.Run("it registers handlers", func(t *testing.T) {
ctrl := gomock.NewController(t)
router := core.NewMockEchoRouter(ctrl)

router.EXPECT().Use(gomock.Any()).AnyTimes()
router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()

(&Wrapper{
storageEngine: storage.NewTestStorageEngine(t),
}).Routes(router)
})
t.Run("cache middleware URLs match registered paths", func(t *testing.T) {
ctrl := gomock.NewController(t)
router := core.NewMockEchoRouter(ctrl)

var registeredPaths []string
router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(path string, _ echo.HandlerFunc, _ ...echo.MiddlewareFunc) *echo.Route {
registeredPaths = append(registeredPaths, path)
return nil
}).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(path string, _ echo.HandlerFunc, _ ...echo.MiddlewareFunc) *echo.Route {
registeredPaths = append(registeredPaths, path)
return nil
}).AnyTimes()
router.EXPECT().Use(gomock.Any()).AnyTimes()
(&Wrapper{
storageEngine: storage.NewTestStorageEngine(t),
}).Routes(router)

// Check that all cache-control max-age paths are actual paths
for _, path := range cacheControlMaxAgeURLs {
assert.Contains(t, registeredPaths, path)
}
// Check that all cache-control no-cache paths are actual paths
for _, path := range cacheControlNoCacheURLs {
assert.Contains(t, registeredPaths, path)
}
})
}

func TestWrapper_middleware(t *testing.T) {
Expand Down
76 changes: 76 additions & 0 deletions http/cache/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (C) 2024 Nuts community
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/

package cache

import (
"fmt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"time"
)

// Middleware is a middleware that sets the Cache-Control header (no-cache or max-age) for the given request URLs.
// Use MaxAge or NoCache to create a new instance.
type Middleware struct {
Skipper middleware.Skipper
maxAge time.Duration
}

func (m Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !m.Skipper(c) {
if m.maxAge == -1 {
c.Response().Header().Set("Cache-Control", "no-cache")
// Pragma is deprecated (HTTP/1.0) but it's specified by OAuth2 RFC6749,
// so specify it for compliance.
c.Response().Header().Set("Pragma", "no-store")
} else if m.maxAge > 0 {
c.Response().Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(m.maxAge.Seconds())))
}
}
return next(c)
}
}

// MaxAge creates a new middleware that sets the Cache-Control header to the given max-age for the given request URLs.
func MaxAge(maxAge time.Duration, requestURLs ...string) Middleware {
return Middleware{
Skipper: matchRequestPathSkipper(requestURLs),
maxAge: maxAge,
}
}

// NoCache creates a new middleware that sets the Cache-Control header to no-cache for the given request URLs.
func NoCache(requestURLs ...string) Middleware {
return Middleware{
Skipper: matchRequestPathSkipper(requestURLs),
maxAge: -1,
}
}

func matchRequestPathSkipper(requestURLs []string) func(c echo.Context) bool {
return func(c echo.Context) bool {
for _, curr := range requestURLs {
if c.Request().URL.Path == curr {
return false
}
}
return true
}
}
84 changes: 84 additions & 0 deletions http/cache/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (C) 2024 Nuts community
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/

package cache

import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/require"
"net/http/httptest"
"testing"
"time"
)

func TestMaxAge(t *testing.T) {
t.Run("match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/a", nil), httpResponse)

err := MaxAge(time.Minute, "/a", "/b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Equal(t, "max-age=60", httpResponse.Header().Get("Cache-Control"))
})
t.Run("no match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse)

err := MaxAge(time.Minute, "/a", "/b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Empty(t, httpResponse.Header().Get("Cache-Control"))
})

}

func TestNoCache(t *testing.T) {
t.Run("match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/a", nil), httpResponse)

err := NoCache("/a", "/b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Equal(t, "no-cache", httpResponse.Header().Get("Cache-Control"))
require.Equal(t, "no-store", httpResponse.Header().Get("Pragma"))
})
t.Run("no match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse)

err := NoCache("/a", "/b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Empty(t, httpResponse.Header().Get("Cache-Control"))
require.Empty(t, httpResponse.Header().Get("Pragma"))
})
}

0 comments on commit 050e189

Please sign in to comment.