From 38a233d3a3d3b97b10bc64353f61021cd85e6a5f Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Thu, 10 Aug 2023 11:05:16 -0700 Subject: [PATCH] Introduce pre redirect hook plugin during auth callback (#601) * Adding a predredirect hook plugin Signed-off-by: pmahindrakar-oss * Add test logs Signed-off-by: pmahindrakar-oss * test logs Signed-off-by: pmahindrakar-oss * fix Signed-off-by: pmahindrakar-oss * Reading identity token for getting subject Signed-off-by: pmahindrakar-oss * reverting Signed-off-by: pmahindrakar-oss * Adding PreRedirectHookError Signed-off-by: pmahindrakar-oss * Add some more tests Signed-off-by: pmahindrakar-oss * lint fixes Signed-off-by: pmahindrakar-oss * removed log line Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 50 ++++++++++++++++++++++++++------ auth/handlers_test.go | 61 +++++++++++++++++++++++++++++++++------- pkg/server/service.go | 53 +++++++++++++++++----------------- plugins/registry.go | 1 + plugins/registry_test.go | 20 +++++++++++++ 5 files changed, 139 insertions(+), 46 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index 4133af4e7..9604d90ec 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,11 +8,6 @@ import ( "strings" "time" - "github.com/flyteorg/flyteadmin/auth/interfaces" - "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "golang.org/x/oauth2" "google.golang.org/grpc" @@ -21,6 +16,13 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "google.golang.org/protobuf/runtime/protoiface" + + "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/plugins" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" ) const ( @@ -29,6 +31,23 @@ const ( FromHTTPVal = "true" ) +type PreRedirectHookError struct { + Message string + Code int +} + +func (e *PreRedirectHookError) Error() string { + return e.Message +} + +// PreRedirectHookFunc Interface used for running custom code before the redirect happens during a successful auth flow. +// This might be useful in cases where the auth flow allows the user to login since the IDP has been configured +// for eg: to allow all users from a particular domain to login +// but you want to restrict access to only a particular set of user ids. eg : users@domain.com are allowed to login but user user1@domain.com, user2@domain.com +// should only be allowed +// PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error +// without which the current usage in GetCallbackHandler will set this to InternalServerError +type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error @@ -39,11 +58,11 @@ type AuthenticatedClientMeta struct { Subject string } -func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext) { +func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) { // Add HTTP handlers for OAuth2 endpoints handler.HandleFunc("/login", RefreshTokensIfExists(ctx, authCtx, GetLoginHandler(ctx, authCtx))) - handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx)) + handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx, pluginRegistry)) // The metadata endpoint is an RFC-defined constant, but we need a leading / for the handler to pattern match correctly. handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx)) @@ -129,14 +148,13 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte logger.Errorf(ctx, "Was not able to create a redirect cookie") } } - http.Redirect(writer, request, url, http.StatusTemporaryRedirect) } } // GetCallbackHandler returns a handler that is called by the OIdC provider with the authorization code to complete // the user authentication flow. -func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc { +func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { logger.Debugf(ctx, "Running callback handler... for RequestURI %v", request.RequestURI) authorizationCode := request.FormValue(AuthorizationResponseCodeType) @@ -178,6 +196,20 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) + if preRedirectHook != nil { + logger.Infof(ctx, "preRedirect hook is set") + if err := preRedirectHook(ctx, authCtx, request, writer); err != nil { + logger.Errorf(ctx, "failed the preRedirect hook due %v with status code %v", err.Message, err.Code) + if http.StatusText(err.Code) != "" { + writer.WriteHeader(err.Code) + } else { + writer.WriteHeader(http.StatusInternalServerError) + } + return + } + logger.Info(ctx, "Successfully called the preRedirect hook") + } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } diff --git a/auth/handlers_test.go b/auth/handlers_test.go index 88232de1c..449b13c4a 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -10,18 +10,19 @@ import ( "strings" "testing" + "github.com/coreos/go-oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyteadmin/auth/config" + "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/auth/interfaces/mocks" "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/plugins" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" stdConfig "github.com/flyteorg/flytestdlib/config" - - "github.com/coreos/go-oidc" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/oauth2" ) const ( @@ -81,7 +82,8 @@ func TestGetCallbackHandlerWithErrorOnToken(t *testing.T) { defer localServer.Close() http.DefaultClient = localServer.Client() mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -102,7 +104,8 @@ func TestGetCallbackHandlerWithUnAuthorized(t *testing.T) { defer localServer.Close() http.DefaultClient = localServer.Client() mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) writer := httptest.NewRecorder() callbackHandlerFunc(writer, request) @@ -153,7 +156,8 @@ func TestGetCallbackHandler(t *testing.T) { t.Run("forbidden request when accessing user info", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -172,9 +176,15 @@ func TestGetCallbackHandler(t *testing.T) { assert.Equal(t, "403 Forbidden", writer.Result().Status) }) - t.Run("successful callback and redirect", func(t *testing.T) { + t.Run("successful callback with redirect and successful preredirect hook call", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + return nil + } + + r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc) + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -193,6 +203,37 @@ func TestGetCallbackHandler(t *testing.T) { callbackHandlerFunc(writer, request) assert.Equal(t, "307 Temporary Redirect", writer.Result().Status) }) + + t.Run("successful callback with pre-redirecthook failure", func(t *testing.T) { + mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) + r := plugins.NewRegistry() + var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + return &PreRedirectHookError{ + Code: http.StatusPreconditionFailed, + Message: "precondition error", + } + } + + r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc) + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) + request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) + addCsrfCookie(request) + addStateString(request) + writer := httptest.NewRecorder() + openIDConfigJSON = fmt.Sprintf(`{ + "userinfo_endpoint": "%v/userinfo", + "issuer": "%v", + "authorization_endpoint": "%v/auth", + "token_endpoint": "%v/token", + "jwks_uri": "%v/keys", + "id_token_signing_alg_values_supported": ["RS256"] + }`, issuer, issuer, issuer, issuer, issuer) + oidcProvider, err := oidc.NewProvider(ctx, issuer) + assert.Nil(t, err) + mockAuthCtx.OnOidcProviderMatch().Return(oidcProvider) + callbackHandlerFunc(writer, request) + assert.Equal(t, "412 Precondition Failed", writer.Result().Status) + }) } func TestGetLoginHandler(t *testing.T) { diff --git a/pkg/server/service.go b/pkg/server/service.go index 4a7983087..b060c3ec7 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -9,42 +9,39 @@ import ( "strings" "time" + "github.com/gorilla/handlers" + grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/reflection" "k8s.io/apimachinery/pkg/util/rand" - "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - - runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/storage" - - "github.com/flyteorg/flyteadmin/dataproxy" - "github.com/flyteorg/flyteadmin/plugins" - "github.com/flyteorg/flyteadmin/auth" "github.com/flyteorg/flyteadmin/auth/authzserver" authConfig "github.com/flyteorg/flyteadmin/auth/config" "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/dataproxy" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/config" "github.com/flyteorg/flyteadmin/pkg/rpc" "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" + runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime" runtimeIfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteadmin/plugins" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" - "github.com/gorilla/handlers" - grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "github.com/grpc-ecosystem/grpc-gateway/runtime" - "github.com/pkg/errors" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/health" - "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" ) var defaultCorsHeaders = []string{"Content-Type"} @@ -163,7 +160,7 @@ func healthCheckFunc(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) } -func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, +func newHTTPServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, additionalHandlers map[string]func(http.ResponseWriter, *http.Request), grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { @@ -191,7 +188,7 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig. if cfg.Security.UseAuth { // Add HTTP handlers for OIDC endpoints - auth.RegisterHandlers(ctx, mux, authCtx) + auth.RegisterHandlers(ctx, mux, authCtx, pluginRegistry) // Add HTTP handlers for OAuth2 endpoints authzserver.RegisterHandlers(mux, authCtx) @@ -278,7 +275,8 @@ func generateRequestID() string { func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageConfig *storage.Config, - additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + scope promutils.Scope) error { logger.Infof(ctx, "Serving Flyte Admin Insecure") // This will parse configuration and create the necessary objects for dealing with auth @@ -343,7 +341,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, grpcOptions = append(grpcOptions, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...) + httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...) if err != nil { return err } @@ -390,7 +388,8 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageCfg *storage.Config, - additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + scope promutils.Scope) error { certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) if err != nil { return err @@ -445,7 +444,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c serverOpts = append(serverOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...) + httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...) if err != nil { return err } diff --git a/plugins/registry.go b/plugins/registry.go index 3c2186326..14682f7e8 100644 --- a/plugins/registry.go +++ b/plugins/registry.go @@ -12,6 +12,7 @@ const ( PluginIDWorkflowExecutor PluginID = "WorkflowExecutor" PluginIDDataProxy PluginID = "DataProxy" PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware" + PluginIDPreRedirectHook PluginID = "PreRedirectHook" ) type AtomicRegistry struct { diff --git a/plugins/registry_test.go b/plugins/registry_test.go index 757b596fd..0737c1281 100644 --- a/plugins/registry_test.go +++ b/plugins/registry_test.go @@ -1,6 +1,8 @@ package plugins import ( + "context" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -21,6 +23,24 @@ func TestNewAtomicRegistry(t *testing.T) { assert.Equal(t, 5, r.Get(PluginIDDataProxy)) } +type PreRedirectHookFunc func(ctx context.Context) error + +func TestRedirectHook(t *testing.T) { + ar := NewAtomicRegistry(nil) + r := NewRegistry() + + var redirectHookfn PreRedirectHookFunc = func(ctx context.Context) error { + return fmt.Errorf("redirect hook error") + } + err := r.Register(PluginIDPreRedirectHook, redirectHookfn) + assert.NoError(t, err) + ar.Store(r) + r = ar.Load() + fn := Get[PreRedirectHookFunc](r, PluginIDPreRedirectHook) + err = fn(context.Background()) + assert.Equal(t, fmt.Errorf("redirect hook error"), err) +} + func TestRegistry_RegisterDefault(t *testing.T) { r := NewRegistry() r.RegisterDefault("hello", 5)