diff --git a/filters/auth/auth.go b/filters/auth/auth.go index 5d818b9637..e1b42ba50f 100644 --- a/filters/auth/auth.go +++ b/filters/auth/auth.go @@ -174,16 +174,3 @@ func all(left, right []string) bool { } return true } - -// intersect checks that one string in the left is also in the right -func intersect(left, right []string) bool { - for _, l := range left { - for _, r := range right { - if l == r { - return true - } - } - } - - return false -} diff --git a/filters/auth/oidc.go b/filters/auth/oidc.go index 57c1ef886e..0d08c50d18 100644 --- a/filters/auth/oidc.go +++ b/filters/auth/oidc.go @@ -329,12 +329,12 @@ func (f *tokenOidcFilter) validateAnyClaims(h map[string]interface{}) bool { return false } - keys := make([]string, 0, len(h)) - for k := range h { - keys = append(keys, k) + for _, c := range f.claims { + if _, ok := h[c]; ok { + return true + } } - - return intersect(f.claims, keys) + return false } func (f *tokenOidcFilter) validateAllClaims(h map[string]interface{}) bool { @@ -346,11 +346,12 @@ func (f *tokenOidcFilter) validateAllClaims(h map[string]interface{}) bool { return false } - keys := make([]string, 0, len(h)) - for k := range h { - keys = append(keys, k) + for _, c := range f.claims { + if _, ok := h[c]; !ok { + return false + } } - return all(f.claims, keys) + return true } type OauthState struct { diff --git a/filters/auth/tokeninfo.go b/filters/auth/tokeninfo.go index b7626ca5a8..ee68e0e281 100644 --- a/filters/auth/tokeninfo.go +++ b/filters/auth/tokeninfo.go @@ -271,16 +271,13 @@ func (f *tokeninfoFilter) validateAnyScopes(h map[string]interface{}) bool { if !ok { return false } - var a []string - for i := range v { - s, ok := v[i].(string) - if !ok { - return false + + for _, scope := range f.scopes { + if contains(v, scope) { + return true } - a = append(a, s) } - - return intersect(f.scopes, a) + return false } func (f *tokeninfoFilter) validateAllScopes(h map[string]interface{}) bool { @@ -296,16 +293,13 @@ func (f *tokeninfoFilter) validateAllScopes(h map[string]interface{}) bool { if !ok { return false } - var a []string - for i := range v { - s, ok := v[i].(string) - if !ok { + + for _, scope := range f.scopes { + if !contains(v, scope) { return false } - a = append(a, s) } - - return all(f.scopes, a) + return true } func (f *tokeninfoFilter) validateAnyKV(h map[string]interface{}) bool { @@ -336,6 +330,15 @@ func (f *tokeninfoFilter) validateAllKV(h map[string]interface{}) bool { return true } +func contains(vals []interface{}, s string) bool { + for _, v := range vals { + if v == s { + return true + } + } + return false +} + // Request handles authentication based on the defined auth type. func (f *tokeninfoFilter) Request(ctx filters.FilterContext) { r := ctx.Request() diff --git a/filters/auth/tokeninfo_test.go b/filters/auth/tokeninfo_test.go index 2f402c331c..ceb5dbbb9d 100644 --- a/filters/auth/tokeninfo_test.go +++ b/filters/auth/tokeninfo_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/zalando/skipper/eskip" "github.com/zalando/skipper/filters" + "github.com/zalando/skipper/filters/filtertest" "github.com/zalando/skipper/proxy/proxytest" ) @@ -463,7 +464,7 @@ func TestOAuth2Tokeninfo5xx(t *testing.T) { require.Equal(t, http.StatusUnauthorized, rsp.StatusCode, "auth filter failed got=%d, expected=%d, route=%s", rsp.StatusCode, http.StatusUnauthorized, r) } -func BenchmarkOAuthTokeninfoFilter(b *testing.B) { +func BenchmarkOAuthTokeninfoCreateFilter(b *testing.B) { for i := 0; i < b.N; i++ { var spec filters.Spec args := []interface{}{"uid"} @@ -475,3 +476,102 @@ func BenchmarkOAuthTokeninfoFilter(b *testing.B) { } } } + +func BenchmarkOAuthTokeninfoRequest(b *testing.B) { + b.Run("oauthTokeninfoAllScope", func(b *testing.B) { + spec := NewOAuthTokeninfoAllScope("https://127.0.0.1:12345/token", 3*time.Second) + f, err := spec.CreateFilter([]interface{}{"foobar.read", "foobar.write"}) + require.NoError(b, err) + + ctx := &filtertest.Context{ + FStateBag: map[string]interface{}{ + tokeninfoCacheKey: map[string]interface{}{ + scopeKey: []interface{}{"uid", "foobar.read", "foobar.write"}, + }, + }, + FResponse: &http.Response{}, + } + + f.Request(ctx) + require.False(b, ctx.FServed) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + f.Request(ctx) + } + }) + + b.Run("oauthTokeninfoAnyScope", func(b *testing.B) { + spec := NewOAuthTokeninfoAnyScope("https://127.0.0.1:12345/token", 3*time.Second) + f, err := spec.CreateFilter([]interface{}{"foobar.read", "foobar.write"}) + require.NoError(b, err) + + ctx := &filtertest.Context{ + FStateBag: map[string]interface{}{ + tokeninfoCacheKey: map[string]interface{}{ + scopeKey: []interface{}{"uid", "foobar.write", "foobar.exec"}, + }, + }, + FResponse: &http.Response{}, + } + + f.Request(ctx) + require.False(b, ctx.FServed) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + f.Request(ctx) + } + }) +} + +func TestOAuthTokeninfoAllocs(t *testing.T) { + tio := TokeninfoOptions{ + URL: "https://127.0.0.1:12345/token", + Timeout: 3 * time.Second, + } + + fr := make(filters.Registry) + fr.Register(NewOAuthTokeninfoAllScopeWithOptions(tio)) + fr.Register(NewOAuthTokeninfoAnyScopeWithOptions(tio)) + fr.Register(NewOAuthTokeninfoAllKVWithOptions(tio)) + fr.Register(NewOAuthTokeninfoAnyKVWithOptions(tio)) + + var filters []filters.Filter + for _, def := range eskip.MustParseFilters(` + oauthTokeninfoAnyScope("foobar.read", "foobar.write") -> + oauthTokeninfoAllScope("foobar.read", "foobar.write") -> + oauthTokeninfoAnyKV("k1", "v1", "k2", "v2") -> + oauthTokeninfoAllKV("k1", "v1", "k2", "v2") + `) { + f, err := fr[def.Name].CreateFilter(def.Args) + require.NoError(t, err) + + filters = append(filters, f) + } + + ctx := &filtertest.Context{ + FStateBag: map[string]interface{}{ + tokeninfoCacheKey: map[string]interface{}{ + scopeKey: []interface{}{"uid", "foobar.read", "foobar.write", "foobar.exec"}, + "k1": "v1", + "k2": "v2", + }, + }, + FResponse: &http.Response{}, + } + + allocs := testing.AllocsPerRun(100, func() { + for _, f := range filters { + f.Request(ctx) + } + require.False(t, ctx.FServed) + }) + if allocs != 0.0 { + t.Errorf("Expected zero allocations, got %f", allocs) + } +}