Skip to content

Commit

Permalink
IAM: Set Cache-Control headers for Token Endpoint (no cache) and meta…
Browse files Browse the repository at this point in the history
…data (max-age)
  • Loading branch information
reinkrul committed May 27, 2024
1 parent 7767c1e commit 6560dcc
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 71 deletions.
8 changes: 7 additions & 1 deletion auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ var cacheControlMaxAgeURLs = []string{
"/statuslist/:did/:page",
}

// cacheControlNoCacheURLs holds API endpoints that should have a no-cache cache control header set.
var cacheControlNoCacheURLs = []string{
"/oauth2/:did/token",
}

//go:embed assets
var assetsFS embed.FS

Expand Down Expand Up @@ -152,7 +157,8 @@ func (r Wrapper) Routes(router core.EchoRouter) {
return next(c)
}
}, audit.Middleware(apiModuleName))
router.Use(cache.MaxAge(5*time.Minute, cacheControlMaxAgeURLs).Handle)
router.Use(cache.MaxAge(5*time.Minute, cacheControlMaxAgeURLs...).Handle)
router.Use(cache.NoCache(cacheControlNoCacheURLs...).Handle)
}

func (r Wrapper) strictMiddleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) {
Expand Down
11 changes: 9 additions & 2 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ func TestWrapper_Routes(t *testing.T) {

(&Wrapper{}).Routes(router)
})
t.Run("middleware cache-control: max-age URLs match registered paths", func(t *testing.T) {
t.Run("cache middleware URLs match registered paths", func(t *testing.T) {
ctrl := gomock.NewController(t)
router := core.NewMockEchoRouter(ctrl)

Expand All @@ -724,14 +724,21 @@ func TestWrapper_Routes(t *testing.T) {
registeredPaths = append(registeredPaths, path)
return nil
}).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(path string, _ echo.HandlerFunc, _ ...echo.MiddlewareFunc) *echo.Route {
registeredPaths = append(registeredPaths, path)
return nil
}).AnyTimes()
router.EXPECT().Use(gomock.Any()).AnyTimes()
(&Wrapper{}).Routes(router)

// Check that all cache-control max-age paths are actual paths
for _, path := range cacheControlMaxAgeURLs {
assert.Contains(t, registeredPaths, path)
}
// Check that all cache-control no-cache paths are actual paths
for _, path := range cacheControlNoCacheURLs {
assert.Contains(t, registeredPaths, path)
}
})
}

Expand Down
14 changes: 2 additions & 12 deletions auth/api/iam/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 1 addition & 7 deletions auth/api/iam/openid4vp.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,7 @@ func (r Wrapper) handleAccessTokenRequest(ctx context.Context, request HandleTok
if err != nil {
return nil, oauthError(oauth.ServerError, fmt.Sprintf("failed to create access token: %s", err.Error()))
}
return HandleTokenRequest200JSONResponse{
Body: *response,
Headers: HandleTokenRequest200ResponseHeaders{
CacheControl: "no-cache",
Pragma: "no-cache",
},
}, nil
return HandleTokenRequest200JSONResponse(*response), nil
}

func (r Wrapper) handleCallbackError(request CallbackRequestObject) (CallbackResponseObject, error) {
Expand Down
12 changes: 4 additions & 8 deletions auth/api/iam/openid4vp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,10 @@ func Test_handleAccessTokenRequest(t *testing.T) {
require.NoError(t, err)
token, ok := response.(HandleTokenRequest200JSONResponse)
require.True(t, ok)
// Assert access token
assert.NotEmpty(t, token.Body.AccessToken)
assert.Equal(t, "DPoP", token.Body.TokenType)
assert.Equal(t, 900, *token.Body.ExpiresIn)
assert.Equal(t, "scope", *token.Body.Scope)
// Assert proper caching headers (as specified by OAuth2 RFC)
assert.Equal(t, "no-cache", token.Headers.CacheControl)
assert.Equal(t, "no-cache", token.Headers.Pragma)
assert.NotEmpty(t, token.AccessToken)
assert.Equal(t, "DPoP", token.TokenType)
assert.Equal(t, 900, *token.ExpiresIn)
assert.Equal(t, "scope", *token.Scope)
// authz code is burned
assert.ErrorIs(t, ctx.client.oauthCodeStore().Get(code, new(OAuthSession)), storage.ErrNotFound)
})
Expand Down
8 changes: 1 addition & 7 deletions auth/api/iam/s2s_vptoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,7 @@ func (r Wrapper) handleS2SAccessTokenRequest(ctx context.Context, issuer did.DID
if err != nil {
return nil, err
}
return HandleTokenRequest200JSONResponse{
Body: *response,
Headers: HandleTokenRequest200ResponseHeaders{
CacheControl: "no-cache",
Pragma: "no-cache",
},
}, nil
return HandleTokenRequest200JSONResponse(*response), nil
}

func resolveInputDescriptorValues(presentationDefinitions pe.WalletOwnerMapping, credentialMap map[string]vc.VerifiableCredential) (map[string]any, error) {
Expand Down
24 changes: 10 additions & 14 deletions auth/api/iam/s2s_vptoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,11 @@ func TestWrapper_handleS2SAccessTokenRequest(t *testing.T) {

require.NoError(t, err)
require.IsType(t, HandleTokenRequest200JSONResponse{}, resp)
// Assert access token
tokenResponse := resp.(HandleTokenRequest200JSONResponse)
assert.Equal(t, "DPoP", tokenResponse.Body.TokenType)
assert.Equal(t, requestedScope, *tokenResponse.Body.Scope)
assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.Body.ExpiresIn)
assert.NotEmpty(t, tokenResponse.Body.AccessToken)
// Assert caching headers
assert.Equal(t, "no-cache", tokenResponse.Headers.CacheControl)
assert.Equal(t, "no-cache", tokenResponse.Headers.Pragma)
tokenResponse := TokenResponse(resp.(HandleTokenRequest200JSONResponse))
assert.Equal(t, "DPoP", tokenResponse.TokenType)
assert.Equal(t, requestedScope, *tokenResponse.Scope)
assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.ExpiresIn)
assert.NotEmpty(t, tokenResponse.AccessToken)
})
t.Run("missing presentation expiry date", func(t *testing.T) {
ctx := newTestClient(t)
Expand Down Expand Up @@ -177,11 +173,11 @@ func TestWrapper_handleS2SAccessTokenRequest(t *testing.T) {

require.NoError(t, err)
require.IsType(t, HandleTokenRequest200JSONResponse{}, resp)
tokenResponse := resp.(HandleTokenRequest200JSONResponse)
assert.Equal(t, "DPoP", tokenResponse.Body.TokenType)
assert.Equal(t, requestedScope, *tokenResponse.Body.Scope)
assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.Body.ExpiresIn)
assert.NotEmpty(t, tokenResponse.Body.AccessToken)
tokenResponse := TokenResponse(resp.(HandleTokenRequest200JSONResponse))
assert.Equal(t, "DPoP", tokenResponse.TokenType)
assert.Equal(t, requestedScope, *tokenResponse.Scope)
assert.Equal(t, int(accessTokenValidity.Seconds()), *tokenResponse.ExpiresIn)
assert.NotEmpty(t, tokenResponse.AccessToken)
})
t.Run("VP is not valid JSON", func(t *testing.T) {
ctx := newTestClient(t)
Expand Down
7 changes: 0 additions & 7 deletions docs/_static/auth/iam.partial.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ paths:
responses:
"200":
description: OK
headers:
Cache-Control:
schema:
type: string
Pragma:
schema:
type: string
content:
application/json:
schema:
Expand Down
39 changes: 28 additions & 11 deletions http/cache/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ type Middleware struct {
func (m Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !m.Skipper(c) {
if m.maxAge > 0 {
if m.maxAge == -1 {
c.Response().Header().Set("Cache-Control", "no-cache")
// Pragma is deprecated (HTTP/1.0) but it's specified by OAuth2 RFC6749,
// so specify it for compliance.
c.Response().Header().Set("Pragma", "no-store")
} else if m.maxAge > 0 {
c.Response().Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(m.maxAge.Seconds())))
}
}
Expand All @@ -25,17 +30,29 @@ func (m Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc {
}

// MaxAge creates a new middleware that sets the Cache-Control header to the given max-age for the given request URLs.
func MaxAge(maxAge time.Duration, requestURLs []string) Middleware {
func MaxAge(maxAge time.Duration, requestURLs ...string) Middleware {
return Middleware{
Skipper: func(c echo.Context) bool {
for _, curr := range requestURLs {
// trim leading and trailing /before comparing, just in case
if strings.Trim(c.Request().URL.Path, "/") == strings.Trim(curr, "/") {
return false
}
Skipper: matchRequestPathSkipper(requestURLs),
maxAge: maxAge,
}
}

// NoCache creates a new middleware that sets the Cache-Control header to no-cache for the given request URLs.
func NoCache(requestURLs ...string) Middleware {
return Middleware{
Skipper: matchRequestPathSkipper(requestURLs),
maxAge: -1,
}
}

func matchRequestPathSkipper(requestURLs []string) func(c echo.Context) bool {
return func(c echo.Context) bool {
for _, curr := range requestURLs {
// trim leading and trailing /before comparing, just in case
if strings.Trim(c.Request().URL.Path, "/") == strings.Trim(curr, "/") {
return false
}
return true
},
maxAge: maxAge,
}
return true
}
}
33 changes: 31 additions & 2 deletions http/cache/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestMaxAge(t *testing.T) {
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/a/", nil), httpResponse)

err := MaxAge(time.Minute, []string{"a", "b"}).Handle(func(c echo.Context) error {
err := MaxAge(time.Minute, "a", "b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

Expand All @@ -26,7 +26,7 @@ func TestMaxAge(t *testing.T) {
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse)

err := MaxAge(time.Minute, []string{"a", "b"}).Handle(func(c echo.Context) error {
err := MaxAge(time.Minute, "a", "b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

Expand All @@ -35,3 +35,32 @@ func TestMaxAge(t *testing.T) {
})

}

func TestNoCache(t *testing.T) {
t.Run("match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/a/", nil), httpResponse)

err := NoCache("a", "b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Equal(t, "no-cache", httpResponse.Header().Get("Cache-Control"))
require.Equal(t, "no-store", httpResponse.Header().Get("Pragma"))
})
t.Run("no match", func(t *testing.T) {
e := echo.New()
httpResponse := httptest.NewRecorder()
echoContext := e.NewContext(httptest.NewRequest("GET", "/c", nil), httpResponse)

err := NoCache("a", "b").Handle(func(c echo.Context) error {
return c.String(200, "OK")
})(echoContext)

require.NoError(t, err)
require.Empty(t, httpResponse.Header().Get("Cache-Control"))
require.Empty(t, httpResponse.Header().Get("Pragma"))
})
}

0 comments on commit 6560dcc

Please sign in to comment.