From 6560dcc7eac8a8297503b83e56d9798ca4ad2053 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Mon, 27 May 2024 14:12:08 +0200 Subject: [PATCH] IAM: Set Cache-Control headers for Token Endpoint (no cache) and metadata (max-age) --- auth/api/iam/api.go | 8 +++++- auth/api/iam/api_test.go | 11 +++++++-- auth/api/iam/generated.go | 14 ++--------- auth/api/iam/openid4vp.go | 8 +----- auth/api/iam/openid4vp_test.go | 12 +++------ auth/api/iam/s2s_vptoken.go | 8 +----- auth/api/iam/s2s_vptoken_test.go | 24 ++++++++---------- docs/_static/auth/iam.partial.yaml | 7 ------ http/cache/middleware.go | 39 +++++++++++++++++++++--------- http/cache/middleware_test.go | 33 +++++++++++++++++++++++-- 10 files changed, 93 insertions(+), 71 deletions(-) diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index c33aec433f..84189869f8 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -93,6 +93,11 @@ var cacheControlMaxAgeURLs = []string{ "/statuslist/:did/:page", } +// cacheControlNoCacheURLs holds API endpoints that should have a no-cache cache control header set. +var cacheControlNoCacheURLs = []string{ + "/oauth2/:did/token", +} + //go:embed assets var assetsFS embed.FS @@ -152,7 +157,8 @@ func (r Wrapper) Routes(router core.EchoRouter) { return next(c) } }, audit.Middleware(apiModuleName)) - router.Use(cache.MaxAge(5*time.Minute, cacheControlMaxAgeURLs).Handle) + router.Use(cache.MaxAge(5*time.Minute, cacheControlMaxAgeURLs...).Handle) + router.Use(cache.NoCache(cacheControlNoCacheURLs...).Handle) } func (r Wrapper) strictMiddleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) { diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index 5732d1850c..024f4d09dd 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -715,7 +715,7 @@ func TestWrapper_Routes(t *testing.T) { (&Wrapper{}).Routes(router) }) - t.Run("middleware cache-control: max-age URLs match registered paths", func(t *testing.T) { + t.Run("cache middleware URLs match registered paths", func(t *testing.T) { ctrl := gomock.NewController(t) router := core.NewMockEchoRouter(ctrl) @@ -724,7 +724,10 @@ func TestWrapper_Routes(t *testing.T) { registeredPaths = append(registeredPaths, path) return nil }).AnyTimes() - router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(path string, _ echo.HandlerFunc, _ ...echo.MiddlewareFunc) *echo.Route { + registeredPaths = append(registeredPaths, path) + return nil + }).AnyTimes() router.EXPECT().Use(gomock.Any()).AnyTimes() (&Wrapper{}).Routes(router) @@ -732,6 +735,10 @@ func TestWrapper_Routes(t *testing.T) { for _, path := range cacheControlMaxAgeURLs { assert.Contains(t, registeredPaths, path) } + // Check that all cache-control no-cache paths are actual paths + for _, path := range cacheControlNoCacheURLs { + assert.Contains(t, registeredPaths, path) + } }) } diff --git a/auth/api/iam/generated.go b/auth/api/iam/generated.go index c312c53201..73de8bdb0d 100644 --- a/auth/api/iam/generated.go +++ b/auth/api/iam/generated.go @@ -1772,23 +1772,13 @@ type HandleTokenRequestResponseObject interface { VisitHandleTokenRequestResponse(w http.ResponseWriter) error } -type HandleTokenRequest200ResponseHeaders struct { - CacheControl string - Pragma string -} - -type HandleTokenRequest200JSONResponse struct { - Body TokenResponse - Headers HandleTokenRequest200ResponseHeaders -} +type HandleTokenRequest200JSONResponse TokenResponse func (response HandleTokenRequest200JSONResponse) VisitHandleTokenRequestResponse(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/json") - w.Header().Set("Cache-Control", fmt.Sprint(response.Headers.CacheControl)) - w.Header().Set("Pragma", fmt.Sprint(response.Headers.Pragma)) w.WriteHeader(200) - return json.NewEncoder(w).Encode(response.Body) + return json.NewEncoder(w).Encode(response) } type HandleTokenRequestdefaultJSONResponse struct { diff --git a/auth/api/iam/openid4vp.go b/auth/api/iam/openid4vp.go index a700c1b546..14f2470a9f 100644 --- a/auth/api/iam/openid4vp.go +++ b/auth/api/iam/openid4vp.go @@ -742,13 +742,7 @@ func (r Wrapper) handleAccessTokenRequest(ctx context.Context, request HandleTok if err != nil { return nil, oauthError(oauth.ServerError, fmt.Sprintf("failed to create access token: %s", err.Error())) } - return HandleTokenRequest200JSONResponse{ - Body: *response, - Headers: HandleTokenRequest200ResponseHeaders{ - CacheControl: "no-cache", - Pragma: "no-cache", - }, - }, nil + return HandleTokenRequest200JSONResponse(*response), nil } func (r Wrapper) handleCallbackError(request CallbackRequestObject) (CallbackResponseObject, error) { diff --git a/auth/api/iam/openid4vp_test.go b/auth/api/iam/openid4vp_test.go index 0c9b9c08c2..b500dbd413 100644 --- a/auth/api/iam/openid4vp_test.go +++ b/auth/api/iam/openid4vp_test.go @@ -680,14 +680,10 @@ func Test_handleAccessTokenRequest(t *testing.T) { require.NoError(t, err) token, ok := response.(HandleTokenRequest200JSONResponse) require.True(t, ok) - // Assert access token - assert.NotEmpty(t, token.Body.AccessToken) - assert.Equal(t, "DPoP", token.Body.TokenType) - assert.Equal(t, 900, *token.Body.ExpiresIn) - assert.Equal(t, "scope", *token.Body.Scope) - // Assert proper caching headers (as specified by OAuth2 RFC) - assert.Equal(t, "no-cache", token.Headers.CacheControl) - assert.Equal(t, "no-cache", token.Headers.Pragma) + assert.NotEmpty(t, token.AccessToken) + assert.Equal(t, "DPoP", token.TokenType) + assert.Equal(t, 900, *token.ExpiresIn) + assert.Equal(t, "scope", *token.Scope) // authz code is burned assert.ErrorIs(t, ctx.client.oauthCodeStore().Get(code, new(OAuthSession)), storage.ErrNotFound) }) diff --git a/auth/api/iam/s2s_vptoken.go b/auth/api/iam/s2s_vptoken.go index d7ebfc2aee..d0fff2f01e 100644 --- a/auth/api/iam/s2s_vptoken.go +++ b/auth/api/iam/s2s_vptoken.go @@ -113,13 +113,7 @@ func (r Wrapper) handleS2SAccessTokenRequest(ctx context.Context, issuer did.DID if err != nil { return nil, err } - return HandleTokenRequest200JSONResponse{ - Body: *response, - Headers: HandleTokenRequest200ResponseHeaders{ - CacheControl: "no-cache", - Pragma: "no-cache", - }, - }, nil + return HandleTokenRequest200JSONResponse(*response), nil } func resolveInputDescriptorValues(presentationDefinitions pe.WalletOwnerMapping, credentialMap map[string]vc.VerifiableCredential) (map[string]any, error) { diff --git a/auth/api/iam/s2s_vptoken_test.go b/auth/api/iam/s2s_vptoken_test.go index cb721908ea..847034ad16 100644 --- a/auth/api/iam/s2s_vptoken_test.go +++ b/auth/api/iam/s2s_vptoken_test.go @@ -125,15 +125,11 @@ func TestWrapper_handleS2SAccessTokenRequest(t *testing.T) { require.NoError(t, err) require.IsType(t, HandleTokenRequest200JSONResponse{}, resp) - // Assert access token - tokenResponse := resp.(HandleTokenRequest200JSONResponse) - assert.Equal(t, "DPoP", tokenResponse.Body.TokenType) - assert.Equal(t, requestedScope, *tokenResponse.Body.Scope) - assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.Body.ExpiresIn) - assert.NotEmpty(t, tokenResponse.Body.AccessToken) - // Assert caching headers - assert.Equal(t, "no-cache", tokenResponse.Headers.CacheControl) - assert.Equal(t, "no-cache", tokenResponse.Headers.Pragma) + tokenResponse := TokenResponse(resp.(HandleTokenRequest200JSONResponse)) + assert.Equal(t, "DPoP", tokenResponse.TokenType) + assert.Equal(t, requestedScope, *tokenResponse.Scope) + assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.ExpiresIn) + assert.NotEmpty(t, tokenResponse.AccessToken) }) t.Run("missing presentation expiry date", func(t *testing.T) { ctx := newTestClient(t) @@ -177,11 +173,11 @@ func TestWrapper_handleS2SAccessTokenRequest(t *testing.T) { require.NoError(t, err) require.IsType(t, HandleTokenRequest200JSONResponse{}, resp) - tokenResponse := resp.(HandleTokenRequest200JSONResponse) - assert.Equal(t, "DPoP", tokenResponse.Body.TokenType) - assert.Equal(t, requestedScope, *tokenResponse.Body.Scope) - assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.Body.ExpiresIn) - assert.NotEmpty(t, tokenResponse.Body.AccessToken) + tokenResponse := TokenResponse(resp.(HandleTokenRequest200JSONResponse)) + assert.Equal(t, "DPoP", tokenResponse.TokenType) + assert.Equal(t, requestedScope, *tokenResponse.Scope) + assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.ExpiresIn) + assert.NotEmpty(t, tokenResponse.AccessToken) }) t.Run("VP is not valid JSON", func(t *testing.T) { ctx := newTestClient(t) diff --git a/docs/_static/auth/iam.partial.yaml b/docs/_static/auth/iam.partial.yaml index 6920157ff8..92c6b36646 100644 --- a/docs/_static/auth/iam.partial.yaml +++ b/docs/_static/auth/iam.partial.yaml @@ -87,13 +87,6 @@ paths: responses: "200": description: OK - headers: - Cache-Control: - schema: - type: string - Pragma: - schema: - type: string content: application/json: schema: diff --git a/http/cache/middleware.go b/http/cache/middleware.go index b8d7752af1..3c78f013ce 100644 --- a/http/cache/middleware.go +++ b/http/cache/middleware.go @@ -16,7 +16,12 @@ type Middleware struct { func (m Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if !m.Skipper(c) { - if m.maxAge > 0 { + if m.maxAge == -1 { + c.Response().Header().Set("Cache-Control", "no-cache") + // Pragma is deprecated (HTTP/1.0) but it's specified by OAuth2 RFC6749, + // so specify it for compliance. + c.Response().Header().Set("Pragma", "no-store") + } else if m.maxAge > 0 { c.Response().Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(m.maxAge.Seconds()))) } } @@ -25,17 +30,29 @@ func (m Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc { } // MaxAge creates a new middleware that sets the Cache-Control header to the given max-age for the given request URLs. -func MaxAge(maxAge time.Duration, requestURLs []string) Middleware { +func MaxAge(maxAge time.Duration, requestURLs ...string) Middleware { return Middleware{ - Skipper: func(c echo.Context) bool { - for _, curr := range requestURLs { - // trim leading and trailing /before comparing, just in case - if strings.Trim(c.Request().URL.Path, "/") == strings.Trim(curr, "/") { - return false - } + Skipper: matchRequestPathSkipper(requestURLs), + maxAge: maxAge, + } +} + +// NoCache creates a new middleware that sets the Cache-Control header to no-cache for the given request URLs. +func NoCache(requestURLs ...string) Middleware { + return Middleware{ + Skipper: matchRequestPathSkipper(requestURLs), + maxAge: -1, + } +} + +func matchRequestPathSkipper(requestURLs []string) func(c echo.Context) bool { + return func(c echo.Context) bool { + for _, curr := range requestURLs { + // trim leading and trailing /before comparing, just in case + if strings.Trim(c.Request().URL.Path, "/") == strings.Trim(curr, "/") { + return false } - return true - }, - maxAge: maxAge, + } + return true } } diff --git a/http/cache/middleware_test.go b/http/cache/middleware_test.go index f156412bce..59a6c73ab4 100644 --- a/http/cache/middleware_test.go +++ b/http/cache/middleware_test.go @@ -14,7 +14,7 @@ func TestMaxAge(t *testing.T) { httpResponse := httptest.NewRecorder() echoContext := e.NewContext(httptest.NewRequest("GET", "/a/", nil), httpResponse) - err := MaxAge(time.Minute, []string{"a", "b"}).Handle(func(c echo.Context) error { + err := MaxAge(time.Minute, "a", "b").Handle(func(c echo.Context) error { return c.String(200, "OK") })(echoContext) @@ -26,7 +26,7 @@ func TestMaxAge(t *testing.T) { httpResponse := httptest.NewRecorder() echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse) - err := MaxAge(time.Minute, []string{"a", "b"}).Handle(func(c echo.Context) error { + err := MaxAge(time.Minute, "a", "b").Handle(func(c echo.Context) error { return c.String(200, "OK") })(echoContext) @@ -35,3 +35,32 @@ func TestMaxAge(t *testing.T) { }) } + +func TestNoCache(t *testing.T) { + t.Run("match", func(t *testing.T) { + e := echo.New() + httpResponse := httptest.NewRecorder() + echoContext := e.NewContext(httptest.NewRequest("GET", "/a/", nil), httpResponse) + + err := NoCache("a", "b").Handle(func(c echo.Context) error { + return c.String(200, "OK") + })(echoContext) + + require.NoError(t, err) + require.Equal(t, "no-cache", httpResponse.Header().Get("Cache-Control")) + require.Equal(t, "no-store", httpResponse.Header().Get("Pragma")) + }) + t.Run("no match", func(t *testing.T) { + e := echo.New() + httpResponse := httptest.NewRecorder() + echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse) + + err := NoCache("a", "b").Handle(func(c echo.Context) error { + return c.String(200, "OK") + })(echoContext) + + require.NoError(t, err) + require.Empty(t, httpResponse.Header().Get("Cache-Control")) + require.Empty(t, httpResponse.Header().Get("Pragma")) + }) +}