diff --git a/api/api.go b/api/api.go index fc19c99e4fe..3a6d7349712 100644 --- a/api/api.go +++ b/api/api.go @@ -45,6 +45,7 @@ import ( "www.velocidex.com/golang/velociraptor/api/authenticators" api_proto "www.velocidex.com/golang/velociraptor/api/proto" "www.velocidex.com/golang/velociraptor/api/tables" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" artifacts_proto "www.velocidex.com/golang/velociraptor/artifacts/proto" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/file_store/api" @@ -1217,7 +1218,7 @@ func StartMonitoringService( config_obj.Monitoring.BindAddress, config_obj.Monitoring.BindPort) - mux := http.NewServeMux() + mux := api_utils.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) server := &http.Server{ Addr: bind_addr, diff --git a/api/assets.go b/api/assets.go index d1ebbd7eab0..3b7fe188fa0 100644 --- a/api/assets.go +++ b/api/assets.go @@ -32,6 +32,7 @@ import ( "github.com/lpar/gzipped" context "golang.org/x/net/context" "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/gui/velociraptor" @@ -42,10 +43,10 @@ import ( func install_static_assets( ctx context.Context, - config_obj *config_proto.Config, mux *http.ServeMux) { + config_obj *config_proto.Config, mux *api_utils.ServeMux) { base := utils.GetBasePath(config_obj) dir := utils.Join(base, "/app/") - mux.Handle(dir, ipFilter(config_obj, http.StripPrefix( + mux.Handle(dir, ipFilter(config_obj, api_utils.StripPrefix( dir, fixCSSURLs(config_obj, gzipped.FileServer(NewCachedFilesystem(ctx, gui_assets.HTTP)))))) @@ -75,35 +76,36 @@ func GetTemplateHandler( return nil, err } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userinfo := GetUserInfo(r.Context(), config_obj) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + userinfo := GetUserInfo(r.Context(), config_obj) - // This should never happen! - if userinfo.Name == "" { - returnError(w, 401, "Unauthenticated access.") - return - } + // This should never happen! + if userinfo.Name == "" { + returnError(w, 401, "Unauthenticated access.") + return + } - users := services.GetUserManager() - user_options, err := users.GetUserOptions(r.Context(), userinfo.Name) - if err != nil { - // Options may not exist yet - user_options = &proto.SetGUIOptionsRequest{} - } + users := services.GetUserManager() + user_options, err := users.GetUserOptions(r.Context(), userinfo.Name) + if err != nil { + // Options may not exist yet + user_options = &proto.SetGUIOptionsRequest{} + } - args := velociraptor.HTMLtemplateArgs{ - Timestamp: time.Now().UTC().UnixNano() / 1000, - CsrfToken: csrf.Token(r), - BasePath: utils.GetBasePath(config_obj), - Heading: "Heading", - UserTheme: user_options.Theme, - OrgId: user_options.Org, - } - err = tmpl.Execute(w, args) - if err != nil { - w.WriteHeader(500) - } - }), nil + args := velociraptor.HTMLtemplateArgs{ + Timestamp: time.Now().UTC().UnixNano() / 1000, + CsrfToken: csrf.Token(r), + BasePath: utils.GetBasePath(config_obj), + Heading: "Heading", + UserTheme: user_options.Theme, + OrgId: user_options.Org, + } + err = tmpl.Execute(w, args) + if err != nil { + w.WriteHeader(500) + } + }), nil } // Vite hard compiles the css urls into the bundle so we can not move @@ -112,18 +114,19 @@ func fixCSSURLs(config_obj *config_proto.Config, parent http.Handler) http.Handler { if config_obj.GUI == nil || config_obj.GUI.BasePath == "" { - return parent + return api_utils.HandlerFunc(parent, parent.ServeHTTP). + AddChild("NewInterceptingResponseWriter") } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.HasSuffix(r.URL.Path, ".css") { - parent.ServeHTTP(w, r) - } else { - parent.ServeHTTP( - NewInterceptingResponseWriter( - w, r, config_obj.GUI.BasePath), r) - } - }) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, ".css") { + parent.ServeHTTP(w, r) + } else { + parent.ServeHTTP( + NewInterceptingResponseWriter(config_obj, w, r), r) + } + }).AddChild("NewInterceptingResponseWriter") } type interceptingResponseWriter struct { @@ -153,8 +156,8 @@ func (self *interceptingResponseWriter) Write(buf []byte) (int, error) { } func NewInterceptingResponseWriter( - w http.ResponseWriter, r *http.Request, - base_path string) http.ResponseWriter { + config_obj *config_proto.Config, + w http.ResponseWriter, r *http.Request) http.ResponseWriter { // Try to do brotli compression if it is available. accept_encoding, pres := r.Header["Accept-Encoding"] @@ -166,8 +169,8 @@ func NewInterceptingResponseWriter( return &interceptingResponseWriter{ ResponseWriter: w, from: "url(/app/assets/", - to: fmt.Sprintf("url(/%v/app/assets/", - strings.TrimPrefix(base_path, "/")), + to: fmt.Sprintf("url(%v/app/assets/", + utils.GetBasePath(config_obj)), br_writer: brotli.NewWriter(w), } } @@ -177,6 +180,7 @@ func NewInterceptingResponseWriter( return &interceptingResponseWriter{ ResponseWriter: w, from: "url(/app/assets/", - to: fmt.Sprintf("url(/%v/app/assets/", base_path), + to: fmt.Sprintf("url(%v/app/assets/", + utils.GetBasePath(config_obj)), } } diff --git a/api/authenticators/auth.go b/api/authenticators/auth.go index 6c471d7c4bc..9ee59d112dc 100644 --- a/api/authenticators/auth.go +++ b/api/authenticators/auth.go @@ -22,8 +22,8 @@ var ( // All SSO Authenticators implement this interface. type Authenticator interface { - AddHandlers(mux *http.ServeMux) error - AddLogoff(mux *http.ServeMux) error + AddHandlers(mux *utils.ServeMux) error + AddLogoff(mux *utils.ServeMux) error // Make sure the user is authenticated and has at least read // access to the requested org. @@ -85,8 +85,6 @@ func init() { return &AzureAuthenticator{ config_obj: config_obj, authenticator: auth_config, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }, nil }) @@ -99,8 +97,6 @@ func init() { return &GitHubAuthenticator{ config_obj: config_obj, authenticator: auth_config, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }, nil }) @@ -113,8 +109,6 @@ func init() { return &GoogleAuthenticator{ config_obj: config_obj, authenticator: auth_config, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }, nil }) @@ -127,8 +121,6 @@ func init() { auth_config *config_proto.Authenticator) (Authenticator, error) { return &BasicAuthenticator{ config_obj: config_obj, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }, nil }) @@ -140,8 +132,6 @@ func init() { result := &CertAuthenticator{ config_obj: config_obj, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), x509_roots: x509.NewCertPool(), default_roles: auth_config.DefaultRolesForUnknownUser, } @@ -162,8 +152,6 @@ func init() { return &OidcAuthenticator{ config_obj: config_obj, authenticator: auth_config, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }, nil }) diff --git a/api/authenticators/azure.go b/api/authenticators/azure.go index 10c1f58697c..4920010c7bf 100644 --- a/api/authenticators/azure.go +++ b/api/authenticators/azure.go @@ -1,19 +1,19 @@ /* - Velociraptor - Dig Deeper - Copyright (C) 2019-2024 Rapid7 Inc. +Velociraptor - Dig Deeper +Copyright (C) 2019-2024 Rapid7 Inc. - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published - by the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published +by the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . */ package authenticators @@ -27,6 +27,7 @@ import ( context "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/microsoft" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" @@ -48,7 +49,7 @@ type AzureAuthenticator struct { // The URL that will be used to log in. func (self *AzureAuthenticator) LoginURL() string { - return utils.Join(self.base, "/auth/azure/login") + return "/auth/azure/login" } func (self *AzureAuthenticator) IsPasswordLess() bool { @@ -63,17 +64,19 @@ func (self *AzureAuthenticator) AuthRedirectTemplate() string { return self.authenticator.AuthRedirectTemplate } -func (self *AzureAuthenticator) AddHandlers(mux *http.ServeMux) error { - mux.Handle(utils.Join(self.base, "/auth/azure/login"), +func (self *AzureAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { + mux.Handle(api_utils.GetBasePath(self.config_obj, self.LoginURL()), IpFilter(self.config_obj, self.oauthAzureLogin())) - mux.Handle(utils.Join(self.base, "/auth/azure/callback"), + + mux.Handle(api_utils.GetBasePath(self.config_obj, "/auth/azure/callback"), IpFilter(self.config_obj, self.oauthAzureCallback())) - mux.Handle(utils.Join(self.base, "/auth/azure/picture"), + + mux.Handle(api_utils.GetBasePath(self.config_obj, "/auth/azure/picture"), IpFilter(self.config_obj, self.oauthAzurePicture())) return nil } -func (self *AzureAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *AzureAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } @@ -86,100 +89,97 @@ func (self *AzureAuthenticator) AuthenticateUserHandler( self.config_obj, func(w http.ResponseWriter, r *http.Request, err error, username string) { reject_with_username(self.config_obj, w, r, err, username, - utils.Join(self.base, "/auth/azure/login"), - "Microsoft O365/Azure AD") + self.LoginURL(), "Microsoft O365/Azure AD") }, parent) } +func (self *AzureAuthenticator) GetGenOauthConfig() (*oauth2.Config, error) { + return &oauth2.Config{ + RedirectURL: utils.GetPublicURL(self.config_obj, "/auth/azure/callback"), + ClientID: self.authenticator.OauthClientId, + ClientSecret: self.authenticator.OauthClientSecret, + Scopes: []string{"User.Read"}, + Endpoint: microsoft.AzureADEndpoint(self.authenticator.Tenant), + }, nil +} + func (self *AzureAuthenticator) oauthAzureLogin() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var azureOauthConfig = &oauth2.Config{ - RedirectURL: utils.Join(self.public_url, self.base, - "/auth/azure/callback"), - ClientID: self.authenticator.OauthClientId, - ClientSecret: self.authenticator.OauthClientSecret, - Scopes: []string{"User.Read"}, - Endpoint: microsoft.AzureADEndpoint(self.authenticator.Tenant), - } - - // Create oauthState cookie - oauthState, err := r.Cookie("oauthstate") - if err != nil { - oauthState = generateStateOauthCookie(self.config_obj, w) - } - - u := azureOauthConfig.AuthCodeURL(oauthState.Value) - http.Redirect(w, r, u, http.StatusTemporaryRedirect) - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + azureOauthConfig, _ := self.GetGenOauthConfig() + + // Create oauthState cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + oauthState = generateStateOauthCookie(self.config_obj, w) + } + + u := azureOauthConfig.AuthCodeURL(oauthState.Value) + http.Redirect(w, r, u, http.StatusTemporaryRedirect) + }) } func (self *AzureAuthenticator) oauthAzureCallback() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, _ := r.Cookie("oauthstate") - - if oauthState == nil || r.FormValue("state") != oauthState.Value { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid oauth azure state") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - user_info, err := self.getUserDataFromAzure( - r.Context(), r.FormValue("code")) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromAzure") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - // Create a new token object, specifying signing method and the claims - // you would like it to contain. - cookie, err := getSignedJWTTokenCookie( - self.config_obj, self.authenticator, - &Claims{ - Username: user_info.Mail, - Picture: utils.Join(self.base, "/auth/azure/picture"), - Token: user_info.Token, - }) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromAzure") + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, _ := r.Cookie("oauthstate") + + if oauthState == nil || r.FormValue("state") != oauthState.Value { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid oauth azure state") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + user_info, err := self.getUserDataFromAzure( + r.Context(), r.FormValue("code")) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromAzure") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + // Create a new token object, specifying signing method and the claims + // you would like it to contain. + cookie, err := getSignedJWTTokenCookie( + self.config_obj, self.authenticator, + &Claims{ + Username: user_info.Mail, + Picture: utils.GetPublicURL(self.config_obj) + + "auth/azure/picture", + Token: user_info.Token, + }) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromAzure") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + http.SetCookie(w, cookie) http.Redirect(w, r, utils.Homepage(self.config_obj), http.StatusTemporaryRedirect) - return - } - - http.SetCookie(w, cookie) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - }) -} - -func (self *AzureAuthenticator) getAzureOauthConfig() *oauth2.Config { - return &oauth2.Config{ - RedirectURL: utils.Join(self.public_url, self.base, - "/auth/azure/callback"), - ClientID: self.authenticator.OauthClientId, - ClientSecret: self.authenticator.OauthClientSecret, - Scopes: []string{"User.Read"}, - Endpoint: microsoft.AzureADEndpoint(self.authenticator.Tenant), - } + }) } func (self *AzureAuthenticator) getUserDataFromAzure( ctx context.Context, code string) (*AzureUser, error) { // Use code to get token and get user info from Azure. - azureOauthConfig := self.getAzureOauthConfig() + azureOauthConfig, err := self.GetGenOauthConfig() + if err != nil { + return nil, err + } token, err := azureOauthConfig.Exchange(ctx, code) if err != nil { @@ -221,40 +221,46 @@ func (self *AzureAuthenticator) getUserDataFromAzure( // Get the token from the cookie and request the picture from Azure func (self *AzureAuthenticator) oauthAzurePicture() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - reject := func(err error) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusUnauthorized) - } - - claims, err := getDetailsFromCookie(self.config_obj, r) - if err != nil { - reject(err) - return - } - - oauth_token := &oauth2.Token{} - err = json.Unmarshal([]byte(claims.Token), &oauth_token) - if err != nil { - reject(err) - return - } - - azureOauthConfig := self.getAzureOauthConfig() - response, err := azureOauthConfig.Client(r.Context(), oauth_token).Get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value") - if err != nil { - reject(fmt.Errorf("failed getting photo: %v", err)) - return - } - defer response.Body.Close() - - _, err = io.Copy(w, response.Body) - if err != nil { - reject(fmt.Errorf("failed getting photo: %v", err)) - return - } - - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + + reject := func(err error) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusUnauthorized) + } + + claims, err := getDetailsFromCookie(self.config_obj, r) + if err != nil { + reject(err) + return + } + + oauth_token := &oauth2.Token{} + err = json.Unmarshal([]byte(claims.Token), &oauth_token) + if err != nil { + reject(err) + return + } + + azureOauthConfig, err := self.GetGenOauthConfig() + if err != nil { + reject(err) + return + } + + response, err := azureOauthConfig.Client(r.Context(), oauth_token).Get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value") + if err != nil { + reject(fmt.Errorf("failed getting photo: %v", err)) + return + } + defer response.Body.Close() + + _, err = io.Copy(w, response.Body) + if err != nil { + reject(fmt.Errorf("failed getting photo: %v", err)) + return + } + + }) } diff --git a/api/authenticators/basic.go b/api/authenticators/basic.go index cf3bab449b7..b4edf913931 100644 --- a/api/authenticators/basic.go +++ b/api/authenticators/basic.go @@ -7,6 +7,7 @@ import ( "github.com/Velocidex/ordereddict" "github.com/gorilla/csrf" api_proto "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" @@ -21,34 +22,35 @@ type BasicAuthenticator struct { } // Basic auth does not need any special handlers. -func (self *BasicAuthenticator) AddHandlers(mux *http.ServeMux) error { +func (self *BasicAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { return nil } -func (self *BasicAuthenticator) AddLogoff(mux *http.ServeMux) error { - mux.Handle(utils.Join(self.base, "/app/logoff.html"), +func (self *BasicAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { + mux.Handle(api_utils.GetBasePath(self.config_obj, "/app/logoff.html"), IpFilter(self.config_obj, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - username, _, ok := r.BasicAuth() - if !ok { + api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + username, _, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + } + + // The previous username is given as a query parameter. + params := r.URL.Query() + old_username, ok := params["username"] + if ok && len(old_username) == 1 && old_username[0] != username { + // Authenticated as someone else. + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) http.Error(w, "authorization failed", http.StatusUnauthorized) - return - } - - // The previous username is given as a query parameter. - params := r.URL.Query() - old_username, ok := params["username"] - if ok && len(old_username) == 1 && old_username[0] != username { - // Authenticated as someone else. - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, "authorization failed", http.StatusUnauthorized) - }))) + }))) return nil } @@ -67,77 +69,78 @@ func (self *BasicAuthenticator) AuthRedirectTemplate() string { func (self *BasicAuthenticator) AuthenticateUserHandler( parent http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-CSRF-Token", csrf.Token(r)) - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - - username, password, ok := r.BasicAuth() - if !ok { - http.Error(w, "Not authorized", http.StatusUnauthorized) - return - } - - // Get the full user record with hashes so we can - // verify it below. - users_manager := services.GetUserManager() - user_record, err := users_manager.GetUserWithHashes(r.Context(), - username, username) - if err != nil { - services.LogAudit(r.Context(), - self.config_obj, username, "Unknown username", - ordereddict.NewDict(). - Set("remote", r.RemoteAddr). - Set("status", http.StatusUnauthorized)) - - http.Error(w, "authorization failed", http.StatusUnauthorized) - return - } - - ok, err = users_manager.VerifyPassword(r.Context(), - user_record.Name, user_record.Name, password) - if !ok || err != nil { - services.LogAudit(r.Context(), - self.config_obj, user_record.Name, "Invalid password", - ordereddict.NewDict(). - Set("remote", r.RemoteAddr). - Set("status", http.StatusUnauthorized)) - - http.Error(w, "authorization failed", http.StatusUnauthorized) - return - } - - // Does the user have access to the specified org? - err = CheckOrgAccess(self.config_obj, r, user_record) - if err != nil { - services.LogAudit(r.Context(), - self.config_obj, user_record.Name, "User Unauthorized for Org", - ordereddict.NewDict(). - Set("err", err.Error()). - Set("remote", r.RemoteAddr). - Set("status", http.StatusUnauthorized)) - - // Return status forbidden because we dont want the user - // to reauthenticate - http.Error(w, err.Error(), http.StatusForbidden) - return - } - - // Checking is successful - user authorized. Here we - // build a token to pass to the underlying GRPC - // service with metadata about the user. - user_info := &api_proto.VelociraptorUser{ - Name: user_record.Name, - } - - // Must use json encoding because grpc can not handle - // binary data in metadata. - serialized, _ := json.Marshal(user_info) - ctx := context.WithValue( - r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) - - // Need to call logging after auth so it can access - // the USER value in the context. - GetLoggingHandler(self.config_obj)(parent).ServeHTTP( - w, r.WithContext(ctx)) - }) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-CSRF-Token", csrf.Token(r)) + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + + username, password, ok := r.BasicAuth() + if !ok { + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + + // Get the full user record with hashes so we can + // verify it below. + users_manager := services.GetUserManager() + user_record, err := users_manager.GetUserWithHashes(r.Context(), + username, username) + if err != nil { + services.LogAudit(r.Context(), + self.config_obj, username, "Unknown username", + ordereddict.NewDict(). + Set("remote", r.RemoteAddr). + Set("status", http.StatusUnauthorized)) + + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + } + + ok, err = users_manager.VerifyPassword(r.Context(), + user_record.Name, user_record.Name, password) + if !ok || err != nil { + services.LogAudit(r.Context(), + self.config_obj, user_record.Name, "Invalid password", + ordereddict.NewDict(). + Set("remote", r.RemoteAddr). + Set("status", http.StatusUnauthorized)) + + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + } + + // Does the user have access to the specified org? + err = CheckOrgAccess(self.config_obj, r, user_record) + if err != nil { + services.LogAudit(r.Context(), + self.config_obj, user_record.Name, "User Unauthorized for Org", + ordereddict.NewDict(). + Set("err", err.Error()). + Set("remote", r.RemoteAddr). + Set("status", http.StatusUnauthorized)) + + // Return status forbidden because we dont want the user + // to reauthenticate + http.Error(w, err.Error(), http.StatusForbidden) + return + } + + // Checking is successful - user authorized. Here we + // build a token to pass to the underlying GRPC + // service with metadata about the user. + user_info := &api_proto.VelociraptorUser{ + Name: user_record.Name, + } + + // Must use json encoding because grpc can not handle + // binary data in metadata. + serialized, _ := json.Marshal(user_info) + ctx := context.WithValue( + r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) + + // Need to call logging after auth so it can access + // the USER value in the context. + GetLoggingHandler(self.config_obj)(parent).ServeHTTP( + w, r.WithContext(ctx)) + }) } diff --git a/api/authenticators/certs.go b/api/authenticators/certs.go index 3937435150d..d3f7c77f3b5 100644 --- a/api/authenticators/certs.go +++ b/api/authenticators/certs.go @@ -106,19 +106,20 @@ type CertAuthenticator struct { } // Cert auth does not need any special handlers. -func (self *CertAuthenticator) AddHandlers(mux *http.ServeMux) error { +func (self *CertAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { return nil } // It is not really possible to log off when using client certs -func (self *CertAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *CertAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { mux.Handle(api_utils.Join(self.base, "/app/logoff.html"), IpFilter(self.config_obj, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, "authorization failed", http.StatusUnauthorized) - return - }))) + api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + }))) return nil } @@ -154,90 +155,91 @@ func (self *CertAuthenticator) getUserNameFromTLSCerts(r *http.Request) (string, func (self *CertAuthenticator) AuthenticateUserHandler( parent http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-CSRF-Token", csrf.Token(r)) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-CSRF-Token", csrf.Token(r)) - username, err := self.getUserNameFromTLSCerts(r) - if err != nil { - http.Error(w, - fmt.Sprintf("authorization failed: Client Certificate is not valid: %v", err), - http.StatusUnauthorized) - return - } - - users_manager := services.GetUserManager() - user_record, err := users_manager.GetUser(r.Context(), username, username) - if err != nil { - if errors.Is(err, utils.NotFoundError) || - len(self.default_roles) == 0 { + username, err := self.getUserNameFromTLSCerts(r) + if err != nil { http.Error(w, - fmt.Sprintf("authorization failed for %v: %v", username, err), + fmt.Sprintf("authorization failed: Client Certificate is not valid: %v", err), http.StatusUnauthorized) return } - // Create a new user role on the fly. - policy := &acl_proto.ApiClientACL{ - Roles: self.default_roles, - } - services.LogAudit(r.Context(), - self.config_obj, username, "Automatic User Creation", - ordereddict.NewDict(). - Set("roles", self.default_roles). - Set("remote", r.RemoteAddr)) - - // Use the super user principal to actually add the - // username so we have enough permissions. - err = users_manager.AddUserToOrg(r.Context(), services.AddNewUser, - utils.GetSuperuserName(self.config_obj), username, - []string{"root"}, policy) + users_manager := services.GetUserManager() + user_record, err := users_manager.GetUser(r.Context(), username, username) if err != nil { - http.Error(w, - fmt.Sprintf("authorization failed: automatic user creation: %v", err), - http.StatusUnauthorized) - return + if errors.Is(err, utils.NotFoundError) || + len(self.default_roles) == 0 { + http.Error(w, + fmt.Sprintf("authorization failed for %v: %v", username, err), + http.StatusUnauthorized) + return + } + + // Create a new user role on the fly. + policy := &acl_proto.ApiClientACL{ + Roles: self.default_roles, + } + services.LogAudit(r.Context(), + self.config_obj, username, "Automatic User Creation", + ordereddict.NewDict(). + Set("roles", self.default_roles). + Set("remote", r.RemoteAddr)) + + // Use the super user principal to actually add the + // username so we have enough permissions. + err = users_manager.AddUserToOrg(r.Context(), services.AddNewUser, + utils.GetSuperuserName(self.config_obj), username, + []string{"root"}, policy) + if err != nil { + http.Error(w, + fmt.Sprintf("authorization failed: automatic user creation: %v", err), + http.StatusUnauthorized) + return + } + + user_record, err = users_manager.GetUser(r.Context(), username, username) + if err != nil { + http.Error(w, + fmt.Sprintf("Failed creating user for %v: %v", username, err), + http.StatusUnauthorized) + return + } } - user_record, err = users_manager.GetUser(r.Context(), username, username) + // Does the user have access to the specified org? + err = CheckOrgAccess(self.config_obj, r, user_record) if err != nil { + services.LogAudit(r.Context(), + self.config_obj, user_record.Name, "Unauthorized username", + ordereddict.NewDict(). + Set("remote", r.RemoteAddr). + Set("status", http.StatusUnauthorized)) + http.Error(w, - fmt.Sprintf("Failed creating user for %v: %v", username, err), + fmt.Sprintf("authorization failed: %v", err), http.StatusUnauthorized) return } - } - - // Does the user have access to the specified org? - err = CheckOrgAccess(self.config_obj, r, user_record) - if err != nil { - services.LogAudit(r.Context(), - self.config_obj, user_record.Name, "Unauthorized username", - ordereddict.NewDict(). - Set("remote", r.RemoteAddr). - Set("status", http.StatusUnauthorized)) - - http.Error(w, - fmt.Sprintf("authorization failed: %v", err), - http.StatusUnauthorized) - return - } - // Checking is successful - user authorized. Here we - // build a token to pass to the underlying GRPC - // service with metadata about the user. - user_info := &api_proto.VelociraptorUser{ - Name: user_record.Name, - } + // Checking is successful - user authorized. Here we + // build a token to pass to the underlying GRPC + // service with metadata about the user. + user_info := &api_proto.VelociraptorUser{ + Name: user_record.Name, + } - // Must use json encoding because grpc can not handle - // binary data in metadata. - serialized, _ := json.Marshal(user_info) - ctx := context.WithValue( - r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) - - // Need to call logging after auth so it can access - // the USER value in the context. - GetLoggingHandler(self.config_obj)(parent).ServeHTTP( - w, r.WithContext(ctx)) - }) + // Must use json encoding because grpc can not handle + // binary data in metadata. + serialized, _ := json.Marshal(user_info) + ctx := context.WithValue( + r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) + + // Need to call logging after auth so it can access + // the USER value in the context. + GetLoggingHandler(self.config_obj)(parent).ServeHTTP( + w, r.WithContext(ctx)) + }) } diff --git a/api/authenticators/github.go b/api/authenticators/github.go index 94dfb879ded..e1868e65504 100644 --- a/api/authenticators/github.go +++ b/api/authenticators/github.go @@ -1,19 +1,19 @@ /* - Velociraptor - Dig Deeper - Copyright (C) 2019-2024 Rapid7 Inc. +Velociraptor - Dig Deeper +Copyright (C) 2019-2024 Rapid7 Inc. - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published - by the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published +by the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . */ package authenticators @@ -27,6 +27,7 @@ import ( context "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/github" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" @@ -47,7 +48,11 @@ type GitHubAuthenticator struct { // The URL that will be used to log in. func (self *GitHubAuthenticator) LoginURL() string { - return utils.Join(self.base, "/auth/github/login") + return "/auth/github/login" +} + +func (self *GitHubAuthenticator) CallbackHandler() string { + return "/auth/github/callback" } func (self *GitHubAuthenticator) IsPasswordLess() bool { @@ -62,15 +67,15 @@ func (self *GitHubAuthenticator) AuthRedirectTemplate() string { return self.authenticator.AuthRedirectTemplate } -func (self *GitHubAuthenticator) AddHandlers(mux *http.ServeMux) error { - mux.Handle(utils.Join(self.base, "/auth/github/login"), +func (self *GitHubAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { + mux.Handle(api_utils.GetBasePath(self.config_obj, self.LoginURL()), IpFilter(self.config_obj, self.oauthGithubLogin())) - mux.Handle(utils.Join(self.base, "/auth/github/callback"), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.CallbackHandler()), IpFilter(self.config_obj, self.oauthGithubCallback())) return nil } -func (self *GitHubAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *GitHubAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } @@ -88,114 +93,112 @@ func (self *GitHubAuthenticator) AuthenticateUserHandler( parent) } +func (self *GitHubAuthenticator) GetGenOauthConfig() (*oauth2.Config, error) { + return &oauth2.Config{ + RedirectURL: utils.GetPublicURL(self.config_obj, "/auth/github/callback"), + ClientID: self.authenticator.OauthClientId, + ClientSecret: self.authenticator.OauthClientSecret, + Scopes: []string{"user:email"}, + Endpoint: github.Endpoint, + }, nil +} + func (self *GitHubAuthenticator) oauthGithubLogin() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var githubOauthConfig = &oauth2.Config{ - RedirectURL: utils.Join(self.public_url, self.base, - "/auth/github/callback"), - ClientID: self.authenticator.OauthClientId, - ClientSecret: self.authenticator.OauthClientSecret, - Scopes: []string{"user:email"}, - Endpoint: github.Endpoint, - } - - // Create oauthState cookie - oauthState, err := r.Cookie("oauthstate") - if err != nil { - oauthState = generateStateOauthCookie(self.config_obj, w) - } - - u := githubOauthConfig.AuthCodeURL(oauthState.Value, oauth2.ApprovalForce) - http.Redirect(w, r, u, http.StatusTemporaryRedirect) - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + githubOauthConfig, _ := self.GetGenOauthConfig() + + // Create oauthState cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + oauthState = generateStateOauthCookie(self.config_obj, w) + } + + u := githubOauthConfig.AuthCodeURL(oauthState.Value, oauth2.ApprovalForce) + http.Redirect(w, r, u, http.StatusTemporaryRedirect) + }) } func (self *GitHubAuthenticator) oauthGithubCallback() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, _ := r.Cookie("oauthstate") + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, _ := r.Cookie("oauthstate") + + if oauthState == nil || r.FormValue("state") != oauthState.Value { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid oauth github state") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } - if oauthState == nil || r.FormValue("state") != oauthState.Value { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid oauth github state") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - formError := r.FormValue("error") - if formError != "" { - desc := r.FormValue("error_description") - if desc != "" { - formError = desc + formError := r.FormValue("error") + if formError != "" { + desc := r.FormValue("error_description") + if desc != "" { + formError = desc + } + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": formError, + }).Error("getUserDataFromGithub") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return } - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": formError, - }).Error("getUserDataFromGithub") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - data, err := self.getUserDataFromGithub(r.Context(), r.FormValue("code")) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGithub") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - user_info := &GitHubUser{} - err = json.Unmarshal(data, &user_info) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGithub") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - cookie, err := getSignedJWTTokenCookie( - self.config_obj, self.authenticator, - &Claims{ - Username: user_info.Login, - Picture: user_info.AvatarUrl, - }) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGithub") + + data, err := self.getUserDataFromGithub(r.Context(), r.FormValue("code")) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGithub") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + user_info := &GitHubUser{} + err = json.Unmarshal(data, &user_info) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGithub") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + cookie, err := getSignedJWTTokenCookie( + self.config_obj, self.authenticator, + &Claims{ + Username: user_info.Login, + Picture: user_info.AvatarUrl, + }) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGithub") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + http.SetCookie(w, cookie) http.Redirect(w, r, utils.Homepage(self.config_obj), http.StatusTemporaryRedirect) - return - } - - http.SetCookie(w, cookie) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - }) + }) } func (self *GitHubAuthenticator) getUserDataFromGithub( ctx context.Context, code string) ([]byte, error) { // Use code to get token and get user info from GitHub. - var githubOauthConfig = &oauth2.Config{ - RedirectURL: utils.Join(self.public_url, self.base, - "/auth/github/callback"), - ClientID: self.authenticator.OauthClientId, - ClientSecret: self.authenticator.OauthClientSecret, - Scopes: []string{}, - Endpoint: github.Endpoint, - } + githubOauthConfig, _ := self.GetGenOauthConfig() token, err := githubOauthConfig.Exchange(ctx, code) if err != nil { diff --git a/api/authenticators/google.go b/api/authenticators/google.go index ef23d2e7dbb..16972b71f94 100644 --- a/api/authenticators/google.go +++ b/api/authenticators/google.go @@ -33,6 +33,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" api_proto "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" @@ -52,35 +53,36 @@ type GoogleAuthenticator struct { } func (self *GoogleAuthenticator) LoginHandler() string { - return utils.Join(self.base, "/auth/google/login") + return "/auth/google/login" } // The URL that will be used to log in. func (self *GoogleAuthenticator) LoginURL() string { - return utils.Join(self.base, "/auth/google/login") + return "/auth/google/login" } func (self *GoogleAuthenticator) CallbackHandler() string { - return utils.Join(self.base, "/auth/google/callback") + return "/auth/google/callback" } func (self *GoogleAuthenticator) CallbackURL() string { - return utils.Join(self.public_url, self.base, "/auth/google/callback") + return "/auth/google/callback" } func (self *GoogleAuthenticator) ProviderName() string { return "Google" } -func (self *GoogleAuthenticator) AddHandlers(mux *http.ServeMux) error { - mux.Handle(self.LoginHandler(), +func (self *GoogleAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { + mux.Handle(api_utils.GetBasePath(self.config_obj, self.LoginHandler()), IpFilter(self.config_obj, self.oauthGoogleLogin())) - mux.Handle(self.CallbackHandler(), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.CallbackHandler()), IpFilter(self.config_obj, self.oauthGoogleCallback())) + return nil } -func (self *GoogleAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *GoogleAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } @@ -112,24 +114,19 @@ func (self *GoogleAuthenticator) AuthenticateUserHandler( func (self *GoogleAuthenticator) oauthGoogleLogin() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var googleOauthConfig = &oauth2.Config{ - RedirectURL: self.CallbackURL(), - ClientID: self.authenticator.OauthClientId, - ClientSecret: self.authenticator.OauthClientSecret, - Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, - Endpoint: google.Endpoint, - } - - // Create oauthState cookie - oauthState, err := r.Cookie("oauthstate") - if err != nil { - oauthState = generateStateOauthCookie(self.config_obj, w) - } - - u := googleOauthConfig.AuthCodeURL(oauthState.Value, oauth2.ApprovalForce) - http.Redirect(w, r, u, http.StatusTemporaryRedirect) - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + googleOauthConfig, _ := self.GetGenOauthConfig() + + // Create oauthState cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + oauthState = generateStateOauthCookie(self.config_obj, w) + } + + u := googleOauthConfig.AuthCodeURL(oauthState.Value, oauth2.ApprovalForce) + http.Redirect(w, r, u, http.StatusTemporaryRedirect) + }) } func generateStateOauthCookie( @@ -154,74 +151,80 @@ func generateStateOauthCookie( func (self *GoogleAuthenticator) oauthGoogleCallback() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, _ := r.Cookie("oauthstate") - if oauthState == nil || r.FormValue("state") != oauthState.Value { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid oauth google state") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - data, err := self.getUserDataFromGoogle(r.Context(), r.FormValue("code")) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGoogle") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - user_info := &api_proto.VelociraptorUser{} - err = json.Unmarshal(data, &user_info) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGoogle") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - // Sign and get the complete encoded token as a string using the secret - cookie, err := getSignedJWTTokenCookie( - self.config_obj, self.authenticator, - &Claims{ - Username: user_info.Email, - Picture: user_info.Picture, - }) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("getUserDataFromGoogle") + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, _ := r.Cookie("oauthstate") + if oauthState == nil || r.FormValue("state") != oauthState.Value { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid oauth google state") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + data, err := self.getUserDataFromGoogle(r.Context(), r.FormValue("code")) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGoogle") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + user_info := &api_proto.VelociraptorUser{} + err = json.Unmarshal(data, &user_info) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGoogle") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + // Sign and get the complete encoded token as a string using the secret + cookie, err := getSignedJWTTokenCookie( + self.config_obj, self.authenticator, + &Claims{ + Username: user_info.Email, + Picture: user_info.Picture, + }) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("getUserDataFromGoogle") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + http.SetCookie(w, cookie) http.Redirect(w, r, utils.Homepage(self.config_obj), http.StatusTemporaryRedirect) - return - } - - http.SetCookie(w, cookie) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - }) + }) } -func (self *GoogleAuthenticator) getUserDataFromGoogle( - ctx context.Context, code string) ([]byte, error) { - - // Use code to get token and get user info from Google. - var googleOauthConfig = &oauth2.Config{ - RedirectURL: self.CallbackURL(), +func (self *GoogleAuthenticator) GetGenOauthConfig() (*oauth2.Config, error) { + res := &oauth2.Config{ + RedirectURL: api_utils.GetPublicURL(self.config_obj, self.CallbackURL()), ClientID: self.authenticator.OauthClientId, ClientSecret: self.authenticator.OauthClientSecret, Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, Endpoint: google.Endpoint, } + return res, nil +} + +func (self *GoogleAuthenticator) getUserDataFromGoogle( + ctx context.Context, code string) ([]byte, error) { + + // Use code to get token and get user info from Google. + googleOauthConfig, _ := self.GetGenOauthConfig() token, err := googleOauthConfig.Exchange(ctx, code) if err != nil { @@ -241,31 +244,32 @@ func (self *GoogleAuthenticator) getUserDataFromGoogle( return contents, nil } -func installLogoff(config_obj *config_proto.Config, mux *http.ServeMux) { - base := utils.GetBasePath(config_obj) - mux.Handle(utils.Join(base, "/app/logoff.html"), IpFilter(config_obj, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := r.URL.Query() - old_username, ok := params["username"] - username := "" - if ok && len(old_username) == 1 { - services.LogAudit(r.Context(), - config_obj, old_username[0], "LogOff", ordereddict.NewDict()) - username = old_username[0] - } - - // Clear the cookie - http.SetCookie(w, &http.Cookie{ - Name: "VelociraptorAuth", - Path: utils.GetBaseDirectory(config_obj), - Value: "deleted", - Secure: true, - HttpOnly: true, - Expires: time.Unix(0, 0), - }) - - renderLogoffMessage(config_obj, w, username) - }))) +func installLogoff(config_obj *config_proto.Config, mux *api_utils.ServeMux) { + mux.Handle(utils.GetBasePath(config_obj, "/app/logoff.html"), + IpFilter(config_obj, + api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + params := r.URL.Query() + old_username, ok := params["username"] + username := "" + if ok && len(old_username) == 1 { + services.LogAudit(r.Context(), + config_obj, old_username[0], "LogOff", ordereddict.NewDict()) + username = old_username[0] + } + + // Clear the cookie + http.SetCookie(w, &http.Cookie{ + Name: "VelociraptorAuth", + Path: utils.GetBaseDirectory(config_obj), + Value: "deleted", + Secure: true, + HttpOnly: true, + Expires: time.Unix(0, 0), + }) + + renderLogoffMessage(config_obj, w, username) + }))) } func authenticateUserHandle( @@ -274,57 +278,58 @@ func authenticateUserHandle( err error, username string), parent http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-CSRF-Token", csrf.Token(r)) - - claims, err := getDetailsFromCookie(config_obj, r) - if err != nil { - reject_cb(w, r, err, claims.Username) - return - } - - username := claims.Username - - // Now check if the user is allowed to log in. - users := services.GetUserManager() - user_record, err := users.GetUser(r.Context(), username, username) - if err != nil { - reject_cb(w, r, fmt.Errorf("Invalid user: %v", err), username) - return - } - - // Does the user have access to the specified org? - err = CheckOrgAccess(config_obj, r, user_record) - if err != nil { - reject_cb(w, r, fmt.Errorf("Insufficient permissions: %v", err), user_record.Name) - return - } - - // Checking is successful - user authorized. Here we - // build a token to pass to the underlying GRPC - // service with metadata about the user. - user_info := &api_proto.VelociraptorUser{ - Name: user_record.Name, - Picture: claims.Picture, - } - - // NOTE: This context is NOT the same context that is received - // by the API handlers. This context sits on the incoming side - // of the GRPC gateway. We stuff our data into the - // GRPC_USER_CONTEXT of the context and the code will convert - // this value into a GRPC metadata. - - // Must use json encoding because grpc can not handle - // binary data in metadata. - serialized, _ := json.Marshal(user_info) - ctx := context.WithValue( - r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) - - // Need to call logging after auth so it can access - // the contextKeyUser value in the context. - GetLoggingHandler(config_obj)(parent).ServeHTTP( - w, r.WithContext(ctx)) - }) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-CSRF-Token", csrf.Token(r)) + + claims, err := getDetailsFromCookie(config_obj, r) + if err != nil { + reject_cb(w, r, err, claims.Username) + return + } + + username := claims.Username + + // Now check if the user is allowed to log in. + users := services.GetUserManager() + user_record, err := users.GetUser(r.Context(), username, username) + if err != nil { + reject_cb(w, r, fmt.Errorf("Invalid user: %v", err), username) + return + } + + // Does the user have access to the specified org? + err = CheckOrgAccess(config_obj, r, user_record) + if err != nil { + reject_cb(w, r, fmt.Errorf("Insufficient permissions: %v", err), user_record.Name) + return + } + + // Checking is successful - user authorized. Here we + // build a token to pass to the underlying GRPC + // service with metadata about the user. + user_info := &api_proto.VelociraptorUser{ + Name: user_record.Name, + Picture: claims.Picture, + } + + // NOTE: This context is NOT the same context that is received + // by the API handlers. This context sits on the incoming side + // of the GRPC gateway. We stuff our data into the + // GRPC_USER_CONTEXT of the context and the code will convert + // this value into a GRPC metadata. + + // Must use json encoding because grpc can not handle + // binary data in metadata. + serialized, _ := json.Marshal(user_info) + ctx := context.WithValue( + r.Context(), constants.GRPC_USER_CONTEXT, string(serialized)) + + // Need to call logging after auth so it can access + // the contextKeyUser value in the context. + GetLoggingHandler(config_obj)(parent).ServeHTTP( + w, r.WithContext(ctx)) + }) } func reject_with_username( @@ -350,7 +355,7 @@ func reject_with_username( renderRejectionMessage(config_obj, r, w, err, username, []velociraptor.AuthenticatorInfo{ { - LoginURL: login_url, + LoginURL: api_utils.PublicURL(config_obj, login_url), ProviderName: provider, }, }) diff --git a/api/authenticators/ip_filter.go b/api/authenticators/ip_filter.go index dab36ac3f42..b7dd56f6460 100644 --- a/api/authenticators/ip_filter.go +++ b/api/authenticators/ip_filter.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" ) @@ -14,7 +15,7 @@ func IpFilter(config_obj *config_proto.Config, parent http.Handler) http.Handler { if config_obj.GUI == nil || len(config_obj.GUI.AllowedCidr) == 0 { - return parent + return api_utils.HandlerFunc(parent, parent.ServeHTTP) } ranges := []*net.IPNet{} @@ -28,32 +29,33 @@ func IpFilter(config_obj *config_proto.Config, ranges = append(ranges, cidr_net) } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { - // If the user specified a forwarded header and the header is - // there we must check it. - if config_obj.GUI.ForwardedProxyHeader != "" { - address_string := r.Header.Get(config_obj.GUI.ForwardedProxyHeader) - ips := strings.Split(address_string, ", ") - if len(ips) > 0 { - // CIDR matched allow it. - if matchCidr(ranges, ips...) { - parent.ServeHTTP(w, r) + // If the user specified a forwarded header and the header is + // there we must check it. + if config_obj.GUI.ForwardedProxyHeader != "" { + address_string := r.Header.Get(config_obj.GUI.ForwardedProxyHeader) + ips := strings.Split(address_string, ", ") + if len(ips) > 0 { + // CIDR matched allow it. + if matchCidr(ranges, ips...) { + parent.ServeHTTP(w, r) + return + } + http.Error(w, "rejected", http.StatusUnauthorized) return } - http.Error(w, "rejected", http.StatusUnauthorized) - return } - } - // Try to check the remote address now. - remote_address := strings.Split(r.RemoteAddr, ":")[0] - if matchCidr(ranges, remote_address) { - parent.ServeHTTP(w, r) - return - } - http.Error(w, "rejected", http.StatusUnauthorized) - }) + // Try to check the remote address now. + remote_address := strings.Split(r.RemoteAddr, ":")[0] + if matchCidr(ranges, remote_address) { + parent.ServeHTTP(w, r) + return + } + http.Error(w, "rejected", http.StatusUnauthorized) + }) } func matchCidr(ranges []*net.IPNet, ip_strings ...string) bool { diff --git a/api/authenticators/logging.go b/api/authenticators/logging.go index f60c57d5103..64be2508757 100644 --- a/api/authenticators/logging.go +++ b/api/authenticators/logging.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" api_proto "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" "www.velocidex.com/golang/velociraptor/json" @@ -33,39 +34,40 @@ func GetLoggingHandler(config_obj *config_proto.Config) func(http.Handler) http. logger := logging.GetLogger(config_obj, &logging.GUIComponent) return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rec := &http_utils.StatusRecorder{ - w, - w.(http.Flusher), - 200, nil} - defer func() { - if rec.Status == 500 { - logger.WithFields( - logrus.Fields{ - "method": r.Method, - "url": r.URL.Path, - "remote": r.RemoteAddr, - "error": string(rec.Error), - "user-agent": r.UserAgent(), - "status": rec.Status, - "user": GetUserInfo( - r.Context(), config_obj).Name, - }).Error("") + return api_utils.HandlerFunc(next, + func(w http.ResponseWriter, r *http.Request) { + rec := &http_utils.StatusRecorder{ + w, + w.(http.Flusher), + 200, nil} + defer func() { + if rec.Status == 500 { + logger.WithFields( + logrus.Fields{ + "method": r.Method, + "url": r.URL.Path, + "remote": r.RemoteAddr, + "error": string(rec.Error), + "user-agent": r.UserAgent(), + "status": rec.Status, + "user": GetUserInfo( + r.Context(), config_obj).Name, + }).Error("") - } else { - logger.WithFields( - logrus.Fields{ - "method": r.Method, - "url": r.URL.Path, - "remote": r.RemoteAddr, - "user-agent": r.UserAgent(), - "status": rec.Status, - "user": GetUserInfo( - r.Context(), config_obj).Name, - }).Info("") - } - }() - next.ServeHTTP(rec, r) - }) + } else { + logger.WithFields( + logrus.Fields{ + "method": r.Method, + "url": r.URL.Path, + "remote": r.RemoteAddr, + "user-agent": r.UserAgent(), + "status": rec.Status, + "user": GetUserInfo( + r.Context(), config_obj).Name, + }).Info("") + } + }() + next.ServeHTTP(rec, r) + }) } } diff --git a/api/authenticators/multiple.go b/api/authenticators/multiple.go index be8d6031353..449729f3c92 100644 --- a/api/authenticators/multiple.go +++ b/api/authenticators/multiple.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/Velocidex/ordereddict" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/gui/velociraptor" "www.velocidex.com/golang/velociraptor/services" @@ -16,7 +17,11 @@ type MultiAuthenticator struct { delegate_info []velociraptor.AuthenticatorInfo } -func (self *MultiAuthenticator) AddHandlers(mux *http.ServeMux) error { +func (self *MultiAuthenticator) Delegates() []Authenticator { + return self.delegates +} + +func (self *MultiAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { for _, delegate := range self.delegates { err := delegate.AddHandlers(mux) if err != nil { @@ -26,7 +31,7 @@ func (self *MultiAuthenticator) AddHandlers(mux *http.ServeMux) error { return nil } -func (self *MultiAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *MultiAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } diff --git a/api/authenticators/oidc.go b/api/authenticators/oidc.go index 899e11b9a59..c0bb8f90ff2 100644 --- a/api/authenticators/oidc.go +++ b/api/authenticators/oidc.go @@ -8,11 +8,16 @@ import ( oidc "github.com/coreos/go-oidc/v3/oidc" "github.com/sirupsen/logrus" "golang.org/x/oauth2" - utils "www.velocidex.com/golang/velociraptor/api/utils" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/logging" + "www.velocidex.com/golang/velociraptor/utils" ) +type OIDCConnector interface { + GetGenOauthConfig() (*oauth2.Config, error) +} + type OidcAuthenticator struct { config_obj *config_proto.Config authenticator *config_proto.Authenticator @@ -42,33 +47,37 @@ func (self *OidcAuthenticator) Name() string { func (self *OidcAuthenticator) LoginHandler() string { name := self.authenticator.OidcName if name != "" { - return utils.Join(self.base, "/auth/oidc/", name, "/login") + return api_utils.Join("/auth/oidc/", name, "/login") } - return utils.Join(self.base, "/auth/oidc/login") + return "/auth/oidc/login" } func (self *OidcAuthenticator) LoginURL() string { - return utils.Join(self.public_url, self.LoginHandler()) + return self.LoginHandler() } func (self *OidcAuthenticator) CallbackHandler() string { name := self.authenticator.OidcName if name != "" { - return utils.Join(self.base, "/auth/oidc/", name, "/callback") + return api_utils.Join("/auth/oidc/", name, "/callback") } - return utils.Join(self.base, "/auth/oidc/callback") + return "/auth/oidc/callback" } func (self *OidcAuthenticator) CallbackURL() string { - return utils.Join(self.public_url, self.LoginHandler()) + return self.LoginHandler() } -func (self *OidcAuthenticator) AddHandlers(mux *http.ServeMux) error { +func (self *OidcAuthenticator) GetProvider() (*oidc.Provider, error) { ctx, err := ClientContext(context.Background(), self.config_obj) if err != nil { - return err + return nil, err } - provider, err := oidc.NewProvider(ctx, self.authenticator.OidcIssuer) + return oidc.NewProvider(ctx, self.authenticator.OidcIssuer) +} + +func (self *OidcAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { + provider, err := self.GetProvider() if err != nil { logging.GetLogger(self.config_obj, &logging.GUIComponent). Errorf("can not get information from OIDC provider, "+ @@ -77,14 +86,14 @@ func (self *OidcAuthenticator) AddHandlers(mux *http.ServeMux) error { return err } - mux.Handle(self.LoginHandler(), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.LoginHandler()), IpFilter(self.config_obj, self.oauthOidcLogin(provider))) - mux.Handle(self.CallbackHandler(), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.CallbackHandler()), IpFilter(self.config_obj, self.oauthOidcCallback(provider))) return nil } -func (self *OidcAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *OidcAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } @@ -100,8 +109,9 @@ func (self *OidcAuthenticator) AuthenticateUserHandler( parent) } -func (self *OidcAuthenticator) getGenOauthConfig( - endpoint oauth2.Endpoint, callback string) *oauth2.Config { +func (self *OidcAuthenticator) GetGenOauthConfig() (*oauth2.Config, error) { + + callback := self.CallbackHandler() var scope []string switch strings.ToLower(self.authenticator.Type) { @@ -110,99 +120,113 @@ func (self *OidcAuthenticator) getGenOauthConfig( } return &oauth2.Config{ - RedirectURL: self.config_obj.GUI.PublicUrl + callback[1:], + RedirectURL: api_utils.GetPublicURL(self.config_obj, callback), ClientID: self.authenticator.OauthClientId, ClientSecret: self.authenticator.OauthClientSecret, Scopes: scope, - Endpoint: endpoint, - } + }, nil } func (self *OidcAuthenticator) oauthOidcLogin( provider *oidc.Provider) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - oidcOauthConfig := self.getGenOauthConfig( - provider.Endpoint(), self.CallbackHandler()) - - // Create oauthState cookie - oauthState, err := r.Cookie("oauthstate") - if err != nil { - oauthState = generateStateOauthCookie(self.config_obj, w) - } - - url := oidcOauthConfig.AuthCodeURL(oauthState.Value) - - // Needed for Okta to specify `prompt: login` to avoid consent - // auth on each login. - if self.authenticator.OidcAuthUrlParams != nil { - for k, v := range self.authenticator.OidcAuthUrlParams { - oauth2.SetAuthURLParam(k, v) + + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + oidcOauthConfig, err := self.GetGenOauthConfig() + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("GetGenOauthConfig: %v", err) + http.Error(w, "rejected", http.StatusUnauthorized) + } + oidcOauthConfig.Endpoint = provider.Endpoint() + + utils.Debug(oidcOauthConfig) + + // Create oauthState cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + oauthState = generateStateOauthCookie(self.config_obj, w) + } + + url := oidcOauthConfig.AuthCodeURL(oauthState.Value) + + // Needed for Okta to specify `prompt: login` to avoid consent + // auth on each login. + if self.authenticator.OidcAuthUrlParams != nil { + for k, v := range self.authenticator.OidcAuthUrlParams { + oauth2.SetAuthURLParam(k, v) + } } - } - http.Redirect(w, r, url, http.StatusFound) - }) + http.Redirect(w, r, url, http.StatusFound) + }) } func (self *OidcAuthenticator) oauthOidcCallback( provider *oidc.Provider) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, _ := r.Cookie("oauthstate") - if oauthState == nil || r.FormValue("state") != oauthState.Value { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid oauth state of OIDC") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, _ := r.Cookie("oauthstate") + if oauthState == nil || r.FormValue("state") != oauthState.Value { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid oauth state of OIDC") + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } - oidcOauthConfig := self.getGenOauthConfig( - provider.Endpoint(), self.CallbackHandler()) + oidcOauthConfig, err := self.GetGenOauthConfig() + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("GetGenOauthConfig: %v", err) + http.Error(w, "rejected", http.StatusUnauthorized) + } + oidcOauthConfig.Endpoint = provider.Endpoint() + + ctx, err := ClientContext(r.Context(), self.config_obj) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid client context of OIDC") + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + oauthToken, err := oidcOauthConfig.Exchange(ctx, r.FormValue("code")) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("can not get oauthToken from OIDC provider: %v", err) + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + userInfo, err := provider.UserInfo( + ctx, oauth2.StaticTokenSource(oauthToken)) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("can not get UserInfo from OIDC provider: %v", err) + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } - ctx, err := ClientContext(r.Context(), self.config_obj) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid client context of OIDC") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - oauthToken, err := oidcOauthConfig.Exchange(ctx, r.FormValue("code")) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("can not get oauthToken from OIDC provider: %v", err) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - userInfo, err := provider.UserInfo( - ctx, oauth2.StaticTokenSource(oauthToken)) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("can not get UserInfo from OIDC provider: %v", err) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - cookie, err := getSignedJWTTokenCookie( - self.config_obj, self.authenticator, - &Claims{ - Username: userInfo.Email, - }) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("can not get a signed tokenString") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } + cookie, err := getSignedJWTTokenCookie( + self.config_obj, self.authenticator, + &Claims{ + Username: userInfo.Email, + }) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("can not get a signed tokenString") + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } - http.SetCookie(w, cookie) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - }) + http.SetCookie(w, cookie) + http.Redirect(w, r, api_utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + }) } diff --git a/api/authenticators/oidc_cognito.go b/api/authenticators/oidc_cognito.go index 0ce692ce71a..c0038d979a7 100644 --- a/api/authenticators/oidc_cognito.go +++ b/api/authenticators/oidc_cognito.go @@ -12,6 +12,7 @@ import ( oidc "github.com/coreos/go-oidc/v3/oidc" "github.com/sirupsen/logrus" "golang.org/x/oauth2" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/json" @@ -25,9 +26,8 @@ type OidcAuthenticatorCognito struct { OidcAuthenticator } -func (self *OidcAuthenticatorCognito) AddHandlers(mux *http.ServeMux) error { - provider, err := oidc.NewProvider( - context.Background(), self.authenticator.OidcIssuer) +func (self *OidcAuthenticatorCognito) AddHandlers(mux *api_utils.ServeMux) error { + provider, err := self.GetProvider() if err != nil { logging.GetLogger(self.config_obj, &logging.GUIComponent). Errorf("can not get information from OIDC provider, "+ @@ -36,65 +36,73 @@ func (self *OidcAuthenticatorCognito) AddHandlers(mux *http.ServeMux) error { return err } - mux.Handle(self.LoginHandler(), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.LoginHandler()), IpFilter(self.config_obj, self.oauthOidcLogin(provider))) - mux.Handle(self.CallbackHandler(), + mux.Handle(api_utils.GetBasePath(self.config_obj, self.CallbackHandler()), IpFilter(self.config_obj, self.oauthOidcCallback(provider))) return nil } func (self *OidcAuthenticatorCognito) oauthOidcCallback( provider *oidc.Provider) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, _ := r.Cookie("oauthstate") - if oauthState == nil || r.FormValue("state") != oauthState.Value { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("invalid oauth state of OIDC") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - oidcOauthConfig := self.getGenOauthConfig( - provider.Endpoint(), self.CallbackHandler()) - oauthToken, err := oidcOauthConfig.Exchange(r.Context(), r.FormValue("code")) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("can not get oauthToken from OIDC provider: %v", err) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, _ := r.Cookie("oauthstate") + if oauthState == nil || r.FormValue("state") != oauthState.Value { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("invalid oauth state of OIDC") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + oidcOauthConfig, err := self.GetGenOauthConfig() + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("GetGenOauthConfig: %v", err) + http.Error(w, "rejected", http.StatusUnauthorized) + } + oidcOauthConfig.Endpoint = provider.Endpoint() + + oauthToken, err := oidcOauthConfig.Exchange(r.Context(), r.FormValue("code")) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("can not get oauthToken from OIDC provider: %v", err) + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + userInfo, err := getUserInfo( + r.Context(), provider, oauth2.StaticTokenSource(oauthToken)) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + Error("can not get UserInfo from OIDC provider: %v", err) + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + cookie, err := getSignedJWTTokenCookie( + self.config_obj, self.authenticator, + &Claims{ + Username: userInfo.Email, + }) + if err != nil { + logging.GetLogger(self.config_obj, &logging.GUIComponent). + WithFields(logrus.Fields{ + "err": err.Error(), + }).Error("can not get a signed tokenString") + http.Redirect(w, r, utils.Homepage(self.config_obj), + http.StatusTemporaryRedirect) + return + } + + http.SetCookie(w, cookie) http.Redirect(w, r, utils.Homepage(self.config_obj), http.StatusTemporaryRedirect) - return - } - userInfo, err := getUserInfo( - r.Context(), provider, oauth2.StaticTokenSource(oauthToken)) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - Error("can not get UserInfo from OIDC provider: %v", err) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - cookie, err := getSignedJWTTokenCookie( - self.config_obj, self.authenticator, - &Claims{ - Username: userInfo.Email, - }) - if err != nil { - logging.GetLogger(self.config_obj, &logging.GUIComponent). - WithFields(logrus.Fields{ - "err": err.Error(), - }).Error("can not get a signed tokenString") - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - return - } - - http.SetCookie(w, cookie) - http.Redirect(w, r, utils.Homepage(self.config_obj), - http.StatusTemporaryRedirect) - }) + }) } func init() { @@ -103,8 +111,6 @@ func init() { return &OidcAuthenticatorCognito{OidcAuthenticator{ config_obj: config_obj, authenticator: auth_config, - base: utils.GetBasePath(config_obj), - public_url: utils.GetPublicURL(config_obj), }}, nil }) } diff --git a/api/authenticators/saml.go b/api/authenticators/saml.go index edaa636e6b1..9ff66161787 100644 --- a/api/authenticators/saml.go +++ b/api/authenticators/saml.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/csrf" acl_proto "www.velocidex.com/golang/velociraptor/acls/proto" api_proto "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/constants" crypto_utils "www.velocidex.com/golang/velociraptor/crypto/utils" @@ -43,8 +44,7 @@ func (self *SamlAuthenticator) AuthRedirectTemplate() string { return self.authenticator.AuthRedirectTemplate } -func (self *SamlAuthenticator) AddHandlers(mux *http.ServeMux) error { - +func (self *SamlAuthenticator) AddHandlers(mux *api_utils.ServeMux) error { logger := logging.Manager.GetLogger(self.config_obj, &logging.GUIComponent) key, err := crypto_utils.ParseRsaPrivateKeyFromPemStr([]byte( self.authenticator.SamlPrivateKey)) @@ -99,12 +99,13 @@ func (self *SamlAuthenticator) AddHandlers(mux *http.ServeMux) error { cookieSessionProvider.Codec = jwtSessionCodec samlMiddleware.Session = cookieSessionProvider - mux.Handle("/saml/", IpFilter(self.config_obj, samlMiddleware)) + mux.Handle(api_utils.GetBasePath(self.config_obj, "/saml/"), + IpFilter(self.config_obj, samlMiddleware)) logger.Info("Authentication via SAML enabled") return nil } -func (self *SamlAuthenticator) AddLogoff(mux *http.ServeMux) error { +func (self *SamlAuthenticator) AddLogoff(mux *api_utils.ServeMux) error { installLogoff(self.config_obj, mux) return nil } @@ -114,114 +115,115 @@ func (self *SamlAuthenticator) AuthenticateUserHandler( reject_handler := samlMiddleware.RequireAccount(parent) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-CSRF-Token", csrf.Token(r)) - - session, err := samlMiddleware.Session.GetSession(r) - if session == nil { - reject_handler.ServeHTTP(w, r) - return - } - - sa, ok := session.(samlsp.SessionWithAttributes) - if !ok { - reject_handler.ServeHTTP(w, r) - return - } - - username := sa.GetAttributes().Get(self.user_attribute) - users := services.GetUserManager() - user_record, err := users.GetUser(r.Context(), username, username) - if err != nil { - if !errors.Is(err, utils.NotFoundError) { - http.Error(w, - fmt.Sprintf("authorization failed: %v", err), - http.StatusUnauthorized) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-CSRF-Token", csrf.Token(r)) - services.LogAudit(r.Context(), - self.config_obj, username, "Authorization failed", - ordereddict.NewDict(). - Set("error", err). - Set("username", username). - Set("roles", self.user_roles). - Set("remote", r.RemoteAddr)) + session, err := samlMiddleware.Session.GetSession(r) + if session == nil { + reject_handler.ServeHTTP(w, r) return } - if len(self.user_roles) == 0 { - http.Error(w, - "authorization failed: no saml user roles assigned", - http.StatusUnauthorized) + sa, ok := session.(samlsp.SessionWithAttributes) + if !ok { + reject_handler.ServeHTTP(w, r) + return + } + username := sa.GetAttributes().Get(self.user_attribute) + users := services.GetUserManager() + user_record, err := users.GetUser(r.Context(), username, username) + if err != nil { + if !errors.Is(err, utils.NotFoundError) { + http.Error(w, + fmt.Sprintf("authorization failed: %v", err), + http.StatusUnauthorized) + + services.LogAudit(r.Context(), + self.config_obj, username, "Authorization failed", + ordereddict.NewDict(). + Set("error", err). + Set("username", username). + Set("roles", self.user_roles). + Set("remote", r.RemoteAddr)) + return + } + + if len(self.user_roles) == 0 { + http.Error(w, + "authorization failed: no saml user roles assigned", + http.StatusUnauthorized) + + services.LogAudit(r.Context(), + self.config_obj, username, "Authorization failed: no saml user roles assigned", + ordereddict.NewDict(). + Set("username", username). + Set("roles", self.user_roles). + Set("remote", r.RemoteAddr)) + return + } + + // Create a new user role on the fly. + policy := &acl_proto.ApiClientACL{ + Roles: self.user_roles, + } services.LogAudit(r.Context(), - self.config_obj, username, "Authorization failed: no saml user roles assigned", + self.config_obj, username, "Automatic User Creation", ordereddict.NewDict(). Set("username", username). Set("roles", self.user_roles). Set("remote", r.RemoteAddr)) - return - } - // Create a new user role on the fly. - policy := &acl_proto.ApiClientACL{ - Roles: self.user_roles, + // Use the super user principal to actually add the + // username so we have enough permissions. + err = users.AddUserToOrg(r.Context(), services.AddNewUser, + constants.PinnedServerName, username, + []string{"root"}, policy) + if err != nil { + http.Error(w, + fmt.Sprintf("authorization failed: automatic user creation: %v", err), + http.StatusUnauthorized) + return + } + + user_record, err = users.GetUser(r.Context(), username, username) + if err != nil { + http.Error(w, + fmt.Sprintf("Failed creating user for %v: %v", username, err), + http.StatusUnauthorized) + return + } } - services.LogAudit(r.Context(), - self.config_obj, username, "Automatic User Creation", - ordereddict.NewDict(). - Set("username", username). - Set("roles", self.user_roles). - Set("remote", r.RemoteAddr)) - - // Use the super user principal to actually add the - // username so we have enough permissions. - err = users.AddUserToOrg(r.Context(), services.AddNewUser, - constants.PinnedServerName, username, - []string{"root"}, policy) + + // Does the user have access to the specified org? + err = CheckOrgAccess(self.config_obj, r, user_record) if err != nil { + services.LogAudit(r.Context(), + self.config_obj, username, "authorization failed: user not registered and no saml_user_roles set", + ordereddict.NewDict(). + Set("username", username). + Set("roles", self.user_roles). + Set("remote", r.RemoteAddr). + Set("status", http.StatusUnauthorized)) + http.Error(w, - fmt.Sprintf("authorization failed: automatic user creation: %v", err), + fmt.Sprintf("authorization failed: user not registered - contact your system administrator: %v", err), http.StatusUnauthorized) return } - user_record, err = users.GetUser(r.Context(), username, username) - if err != nil { - http.Error(w, - fmt.Sprintf("Failed creating user for %v: %v", username, err), - http.StatusUnauthorized) - return + user_info := &api_proto.VelociraptorUser{ + Name: user_record.Name, } - } - - // Does the user have access to the specified org? - err = CheckOrgAccess(self.config_obj, r, user_record) - if err != nil { - services.LogAudit(r.Context(), - self.config_obj, username, "authorization failed: user not registered and no saml_user_roles set", - ordereddict.NewDict(). - Set("username", username). - Set("roles", self.user_roles). - Set("remote", r.RemoteAddr). - Set("status", http.StatusUnauthorized)) - - http.Error(w, - fmt.Sprintf("authorization failed: user not registered - contact your system administrator: %v", err), - http.StatusUnauthorized) - return - } - - user_info := &api_proto.VelociraptorUser{ - Name: user_record.Name, - } - - serialized, _ := json.Marshal(user_info) - ctx := context.WithValue( - r.Context(), constants.GRPC_USER_CONTEXT, - string(serialized)) - GetLoggingHandler(self.config_obj)(parent).ServeHTTP( - w, r.WithContext(ctx)) - }) + + serialized, _ := json.Marshal(user_info) + ctx := context.WithValue( + r.Context(), constants.GRPC_USER_CONTEXT, + string(serialized)) + GetLoggingHandler(self.config_obj)(parent).ServeHTTP( + w, r.WithContext(ctx)) + }) } func NewSamlAuthenticator( diff --git a/api/authenticators/template.go b/api/authenticators/template.go index 9ad8bfee001..ce1a1bdf730 100644 --- a/api/authenticators/template.go +++ b/api/authenticators/template.go @@ -75,7 +75,7 @@ func renderLogoffMessage( ErrState: json.MustMarshalString(velociraptor.ErrState{ Type: "Logoff", Username: username, - BasePath: utils.Join(utils.GetBasePath(config_obj), "/"), + BasePath: utils.GetBaseDirectory(config_obj), Authenticators: []velociraptor.AuthenticatorInfo{}, }), }) diff --git a/api/builder.go b/api/builder.go index 6842d9c6f7a..32844ae5c06 100644 --- a/api/builder.go +++ b/api/builder.go @@ -15,6 +15,7 @@ import ( "golang.org/x/crypto/acme/autocert" "www.velocidex.com/golang/velociraptor/api/authenticators" + utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/logging" "www.velocidex.com/golang/velociraptor/server" @@ -116,9 +117,9 @@ func (self *Builder) withAutoCertFrontendSelfSignedGUI( logger.Info("Autocert is enabled but GUI port is not 443, starting Frontend with autocert and GUI with self signed.") if config_obj.Services.GuiServer && config_obj.GUI != nil { - mux := http.NewServeMux() + mux := utils.NewServeMux() - router, err := PrepareGUIMux(ctx, config_obj, server_obj, mux) + router, err := PrepareGUIMux(ctx, config_obj, mux) if err != nil { return err } @@ -139,10 +140,10 @@ func (self *Builder) withAutoCertFrontendSelfSignedGUI( } // Launch a server for the frontend. - mux := http.NewServeMux() + mux := utils.NewServeMux() err := server.PrepareFrontendMux( - config_obj, server_obj, mux) + config_obj, server_obj, mux.ServeMux) if err != nil { return err } @@ -162,16 +163,17 @@ func (self *Builder) WithAutocertGUI( return errors.New("Frontend not configured") } - mux := http.NewServeMux() + mux := utils.NewServeMux() if self.config_obj.Services.FrontendServer { - err := server.PrepareFrontendMux(self.config_obj, self.server_obj, mux) + err := server.PrepareFrontendMux( + self.config_obj, self.server_obj, mux.ServeMux) if err != nil { return err } } - router, err := PrepareGUIMux(ctx, self.config_obj, self.server_obj, mux) + router, err := PrepareGUIMux(ctx, self.config_obj, mux) if err != nil { return err } @@ -188,20 +190,21 @@ func startSharedSelfSignedFrontend( wg *sync.WaitGroup, config_obj *config_proto.Config, server_obj *server.Server) error { - mux := http.NewServeMux() + mux := utils.NewServeMux() if config_obj.Frontend == nil || config_obj.GUI == nil { return errors.New("Frontend not configured") } if config_obj.Services.FrontendServer { - err := server.PrepareFrontendMux(config_obj, server_obj, mux) + err := server.PrepareFrontendMux( + config_obj, server_obj, mux.ServeMux) if err != nil { return err } } - router, err := PrepareGUIMux(ctx, config_obj, server_obj, mux) + router, err := PrepareGUIMux(ctx, config_obj, mux) if err != nil { return err } @@ -248,9 +251,9 @@ func startSelfSignedFrontend( // Launch a new server for the GUI. if config_obj.Services.GuiServer { - mux := http.NewServeMux() + mux := utils.NewServeMux() - router, err := PrepareGUIMux(ctx, config_obj, server_obj, mux) + router, err := PrepareGUIMux(ctx, config_obj, mux) if err != nil { return err } @@ -271,10 +274,10 @@ func startSelfSignedFrontend( } // Launch a server for the frontend. - mux := http.NewServeMux() + mux := utils.NewServeMux() server.PrepareFrontendMux( - config_obj, server_obj, mux) + config_obj, server_obj, mux.ServeMux) if config_obj.Frontend.UsePlainHttp { return StartFrontendPlainHttp( diff --git a/api/csrf.go b/api/csrf.go index a1e59851a92..0bcba0350fb 100644 --- a/api/csrf.go +++ b/api/csrf.go @@ -6,6 +6,7 @@ import ( "os" "github.com/gorilla/csrf" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/logging" ) @@ -19,7 +20,7 @@ func csrfProtect(config_obj *config_proto.Config, if pres && disable_csrf == "1" { logger := logging.GetLogger(config_obj, &logging.GUIComponent) logger.Info("Disabling CSRF protection because environment VELOCIRAPTOR_DISABLE_CSRF is set") - return parent + return api_utils.HandlerFunc(parent, parent.ServeHTTP) } // Derive a CSRF key from the hash of the server's public key. @@ -29,7 +30,8 @@ func csrfProtect(config_obj *config_proto.Config, protectionFn := csrf.Protect(token, csrf.Path("/"), csrf.MaxAge(7*24*60*60)) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - protectionFn(parent).ServeHTTP(w, r) - }) + return api_utils.HandlerFunc(parent, + func(w http.ResponseWriter, r *http.Request) { + protectionFn(parent).ServeHTTP(w, r) + }) } diff --git a/api/download.go b/api/download.go index 22a1798defe..f700bc1dcc2 100644 --- a/api/download.go +++ b/api/download.go @@ -47,6 +47,7 @@ import ( "www.velocidex.com/golang/velociraptor/api/authenticators" api_proto "www.velocidex.com/golang/velociraptor/api/proto" "www.velocidex.com/golang/velociraptor/api/tables" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" config_proto "www.velocidex.com/golang/velociraptor/config/proto" "www.velocidex.com/golang/velociraptor/file_store" "www.velocidex.com/golang/velociraptor/file_store/api" @@ -113,195 +114,196 @@ type vfsFileDownloadRequest struct { // This URL allows the caller to download **any** member of the // filestore (providing they have at least read permissions). func vfsFileDownloadHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - request := vfsFileDownloadRequest{} - decoder := schema.NewDecoder() - decoder.IgnoreUnknownKeys(true) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + request := vfsFileDownloadRequest{} + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) - err := decoder.Decode(&request, r.URL.Query()) - if err != nil { - returnError(w, 403, "Error "+err.Error()) - return - } + err := decoder.Decode(&request, r.URL.Query()) + if err != nil { + returnError(w, 403, "Error "+err.Error()) + return + } - org_id := request.OrgId - if org_id == "" { - org_id = authenticators.GetOrgIdFromRequest(r) - } + org_id := request.OrgId + if org_id == "" { + org_id = authenticators.GetOrgIdFromRequest(r) + } - org_id = utils.NormalizedOrgId(org_id) + org_id = utils.NormalizedOrgId(org_id) - org_manager, err := services.GetOrgManager() - if err != nil { - returnError(w, 404, err.Error()) - return - } + org_manager, err := services.GetOrgManager() + if err != nil { + returnError(w, 404, err.Error()) + return + } - org_config_obj, err := org_manager.GetOrgConfig(org_id) - if err != nil { - returnError(w, 404, err.Error()) - return - } + org_config_obj, err := org_manager.GetOrgConfig(org_id) + if err != nil { + returnError(w, 404, err.Error()) + return + } + + // Where to read from the file store + var path_spec api.FSPathSpec - // Where to read from the file store - var path_spec api.FSPathSpec + // The filename for the attachment header. + var filename string - // The filename for the attachment header. - var filename string + client_path_manager := paths.NewClientPathManager(request.ClientId) - client_path_manager := paths.NewClientPathManager(request.ClientId) + // Newer API calls pass the filestore components directly + if len(request.FSComponents) > 0 { + path_spec = path_specs.NewUnsafeFilestorePath(request.FSComponents...). + SetType(api.PATH_TYPE_FILESTORE_ANY) - // Newer API calls pass the filestore components directly - if len(request.FSComponents) > 0 { - path_spec = path_specs.NewUnsafeFilestorePath(request.FSComponents...). - SetType(api.PATH_TYPE_FILESTORE_ANY) + filename = utils.Base(request.VfsPath) - filename = utils.Base(request.VfsPath) + // Uploads table has direct vfs paths + } else if request.VfsPath != "" { + path_spec, err = client_path_manager.GetUploadsFileFromVFSPath( + request.VfsPath) + if err != nil { + returnError(w, 404, err.Error()) + return + } + filename = path_spec.Base() - // Uploads table has direct vfs paths - } else if request.VfsPath != "" { - path_spec, err = client_path_manager.GetUploadsFileFromVFSPath( - request.VfsPath) + } else { + // Just reject the request + returnError(w, 404, "") + return + } + + file, err := file_store.GetFileStore(org_config_obj).ReadFile(path_spec) if err != nil { returnError(w, 404, err.Error()) return } - filename = path_spec.Base() - - } else { - // Just reject the request - returnError(w, 404, "") - return - } - - file, err := file_store.GetFileStore(org_config_obj).ReadFile(path_spec) - if err != nil { - returnError(w, 404, err.Error()) - return - } - defer file.Close() - - if r.Method == "HEAD" { - returnError(w, 200, "Ok") - return - } + defer file.Close() - // We need to figure out the total size of the upload to set - // in the Content Length header. There are three - // possibilities: - // 1. The file is not sparse - // 2. The file is sparse and we are not padding. - // 3. The file is sparse and we are padding it. - var reader_at io.ReaderAt = utils.MakeReaderAtter(file) - var total_size int - - index, err := getIndex(org_config_obj, path_spec) - - // If the file is sparse, we use the sparse reader. - if err == nil && request.Padding && len(index.Ranges) > 0 { - if !uploads.ShouldPadFile(org_config_obj, index) { - returnError(w, 400, "Sparse file is too sparse - unable to pad") + if r.Method == "HEAD" { + returnError(w, 200, "Ok") return } - reader_at = &utils.RangedReader{ - ReaderAt: reader_at, - Index: index, - } + // We need to figure out the total size of the upload to set + // in the Content Length header. There are three + // possibilities: + // 1. The file is not sparse + // 2. The file is sparse and we are not padding. + // 3. The file is sparse and we are padding it. + var reader_at io.ReaderAt = utils.MakeReaderAtter(file) + var total_size int + + index, err := getIndex(org_config_obj, path_spec) + + // If the file is sparse, we use the sparse reader. + if err == nil && request.Padding && len(index.Ranges) > 0 { + if !uploads.ShouldPadFile(org_config_obj, index) { + returnError(w, 400, "Sparse file is too sparse - unable to pad") + return + } - total_size = calculateTotalSizeWithPadding(index) - } else { - total_size = calculateTotalReaderSize(file) - } + reader_at = &utils.RangedReader{ + ReaderAt: reader_at, + Index: index, + } - if request.TextFilter { - output, next_offset, err := filterData(reader_at, request) - if err != nil { - returnError(w, 500, err.Error()) - return + total_size = calculateTotalSizeWithPadding(index) + } else { + total_size = calculateTotalReaderSize(file) } - w.Header().Set("Content-Disposition", "attachment; "+ - sanitizeFilenameForAttachment(filename)) - w.Header().Set("Content-Type", - utils.GetMimeString(output, utils.AutoDetectMime(request.DetectMime))) - w.Header().Set("Content-Range", - fmt.Sprintf("bytes %d-%d/%d", request.Offset, next_offset, total_size)) - w.WriteHeader(200) + if request.TextFilter { + output, next_offset, err := filterData(reader_at, request) + if err != nil { + returnError(w, 500, err.Error()) + return + } - _, _ = w.Write(output) - return - } + w.Header().Set("Content-Disposition", "attachment; "+ + sanitizeFilenameForAttachment(filename)) + w.Header().Set("Content-Type", + utils.GetMimeString(output, utils.AutoDetectMime(request.DetectMime))) + w.Header().Set("Content-Range", + fmt.Sprintf("bytes %d-%d/%d", request.Offset, next_offset, total_size)) + w.WriteHeader(200) - // If the user requested the whole file, and also has password - // set we send them a zip file with the entire thing - if request.ZipFile { - err = streamZipFile(r.Context(), org_config_obj, w, file, filename) - if err == nil { + _, _ = w.Write(output) return } - } - emitContentLength(w, int(request.Offset), int(request.Length), total_size) + // If the user requested the whole file, and also has password + // set we send them a zip file with the entire thing + if request.ZipFile { + err = streamZipFile(r.Context(), org_config_obj, w, file, filename) + if err == nil { + return + } + } - offset := request.Offset + emitContentLength(w, int(request.Offset), int(request.Length), total_size) - // Read the first buffer now so we can report errors - length_sent := 0 - headers_sent := false + offset := request.Offset - // Only allow limited size buffers to be requested by the user. - var buf []byte - if request.Length == 0 || request.Length >= BUFSIZE { - buf = pool.Get().([]byte) - defer pool.Put(buf) + // Read the first buffer now so we can report errors + length_sent := 0 + headers_sent := false - } else { - buf = make([]byte, request.Length) - } + // Only allow limited size buffers to be requested by the user. + var buf []byte + if request.Length == 0 || request.Length >= BUFSIZE { + buf = pool.Get().([]byte) + defer pool.Put(buf) + + } else { + buf = make([]byte, request.Length) + } - for { - n, err := reader_at.ReadAt(buf, offset) - if err != nil && err != io.EOF { - // Only send errors if the headers have not yet been - // sent. + for { + n, err := reader_at.ReadAt(buf, offset) + if err != nil && err != io.EOF { + // Only send errors if the headers have not yet been + // sent. + if !headers_sent { + returnError(w, 500, err.Error()) + headers_sent = true + } + return + } + if request.Length != 0 { + length_to_send := request.Length - length_sent + if n > length_to_send { + n = length_to_send + } + } + if n <= 0 { + return + } + + // Write an ok status which includes the attachment name + // but only if no other data was sent. if !headers_sent { - returnError(w, 500, err.Error()) + w.Header().Set("Content-Disposition", "attachment; "+ + sanitizeFilenameForAttachment(filename)) + w.Header().Set("Content-Type", + utils.GetMimeString(buf[:n], + utils.AutoDetectMime(request.DetectMime))) + w.WriteHeader(200) headers_sent = true } - return - } - if request.Length != 0 { - length_to_send := request.Length - length_sent - if n > length_to_send { - n = length_to_send - } - } - if n <= 0 { - return - } - // Write an ok status which includes the attachment name - // but only if no other data was sent. - if !headers_sent { - w.Header().Set("Content-Disposition", "attachment; "+ - sanitizeFilenameForAttachment(filename)) - w.Header().Set("Content-Type", - utils.GetMimeString(buf[:n], - utils.AutoDetectMime(request.DetectMime))) - w.WriteHeader(200) - headers_sent = true - } + written, err := w.Write(buf[:n]) + if err != nil { + return + } - written, err := w.Write(buf[:n]) - if err != nil { - return + length_sent += written + offset += int64(n) } - - length_sent += written - offset += int64(n) - } - }) + }) } // Read data from offset and filter it until the requested number of @@ -444,82 +446,83 @@ func getTransformer( } func downloadFileStore(prefix []string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path_spec := paths.FSPathSpecFromClientPath(r.URL.Path) - components := path_spec.Components() - - // make sure the prefix is correct - for i, p := range prefix { - if len(components) <= i || p != components[i] { - returnError(w, 404, "Not Found") - return + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + path_spec := paths.FSPathSpecFromClientPath(r.URL.Path) + components := path_spec.Components() + + // make sure the prefix is correct + for i, p := range prefix { + if len(components) <= i || p != components[i] { + returnError(w, 404, "Not Found") + return + } } - } - org_id := authenticators.GetOrgIdFromRequest(r) - org_manager, err := services.GetOrgManager() - if err != nil { - returnError(w, 404, err.Error()) - return - } + org_id := authenticators.GetOrgIdFromRequest(r) + org_manager, err := services.GetOrgManager() + if err != nil { + returnError(w, 404, err.Error()) + return + } - org_config_obj, err := org_manager.GetOrgConfig(org_id) - if err != nil { - returnError(w, 404, err.Error()) - return - } + org_config_obj, err := org_manager.GetOrgConfig(org_id) + if err != nil { + returnError(w, 404, err.Error()) + return + } - // The following is not strictly necessary because this - // function is behind the authenticator middleware which means - // that if we get here the user is already authenticated and - // has at least read permissions on this org. But we check - // again to make sure we are resilient against possible - // regressions in the authenticator code. - users := services.GetUserManager() - user_record, err := users.GetUserFromHTTPContext(r.Context()) - if err != nil { - returnError(w, 404, err.Error()) - return - } + // The following is not strictly necessary because this + // function is behind the authenticator middleware which means + // that if we get here the user is already authenticated and + // has at least read permissions on this org. But we check + // again to make sure we are resilient against possible + // regressions in the authenticator code. + users := services.GetUserManager() + user_record, err := users.GetUserFromHTTPContext(r.Context()) + if err != nil { + returnError(w, 404, err.Error()) + return + } - principal := user_record.Name - permissions := acls.READ_RESULTS - perm, err := services.CheckAccess(org_config_obj, principal, permissions) - if !perm || err != nil { - returnError(w, 403, "User is not allowed to read files.") - return - } + principal := user_record.Name + permissions := acls.READ_RESULTS + perm, err := services.CheckAccess(org_config_obj, principal, permissions) + if !perm || err != nil { + returnError(w, 403, "User is not allowed to read files.") + return + } - file_store_factory := file_store.GetFileStore(org_config_obj) - fd, err := file_store_factory.ReadFile(path_spec) - if err != nil { - returnError(w, 404, err.Error()) - return - } + file_store_factory := file_store.GetFileStore(org_config_obj) + fd, err := file_store_factory.ReadFile(path_spec) + if err != nil { + returnError(w, 404, err.Error()) + return + } - buf := pool.Get().([]byte) - defer pool.Put(buf) + buf := pool.Get().([]byte) + defer pool.Put(buf) - // Read the first buffer for mime detection. - n, err := fd.Read(buf) - if err != nil { - returnError(w, 404, err.Error()) - return - } + // Read the first buffer for mime detection. + n, err := fd.Read(buf) + if err != nil { + returnError(w, 404, err.Error()) + return + } - // From here on we already sent the headers and we can - // not really report an error to the client. - w.Header().Set("Content-Disposition", "attachment; "+ - sanitizePathspecForAttachment(path_spec)) + // From here on we already sent the headers and we can + // not really report an error to the client. + w.Header().Set("Content-Disposition", "attachment; "+ + sanitizePathspecForAttachment(path_spec)) - w.Header().Set("Content-Type", - utils.GetMimeString(buf[:n], utils.AutoDetectMime(true))) - w.WriteHeader(200) - w.Write(buf[:n]) + w.Header().Set("Content-Type", + utils.GetMimeString(buf[:n], utils.AutoDetectMime(true))) + w.WriteHeader(200) + w.Write(buf[:n]) - // Copy the rest directly. - utils.Copy(r.Context(), w, fd) - }) + // Copy the rest directly. + utils.Copy(r.Context(), w, fd) + }) } // Allowed chars in non extended names @@ -557,124 +560,125 @@ func sanitizeFilenameForAttachment(base_filename string) string { // Download the table as specified by the v1/GetTable API. func downloadTable() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - request := &api_proto.GetTableRequest{} - decoder := schema.NewDecoder() - decoder.IgnoreUnknownKeys(true) - - decoder.SetAliasTag("json") - err := decoder.Decode(request, r.URL.Query()) - if err != nil { - returnError(w, 404, err.Error()) - return - } - - org_manager, err := services.GetOrgManager() - if err != nil { - returnError(w, 404, err.Error()) - return - } - - org_config_obj, err := org_manager.GetOrgConfig(request.OrgId) - if err != nil { - returnError(w, 404, err.Error()) - return - } - - row_chan, closer, log_path, err := getRows( - r.Context(), org_config_obj, request) - if err != nil { - returnError(w, 400, "Invalid request") - return - } - defer closer() - - transform := getTransformer(r.Context(), org_config_obj, request) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + request := &api_proto.GetTableRequest{} + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + decoder.SetAliasTag("json") + err := decoder.Decode(request, r.URL.Query()) + if err != nil { + returnError(w, 404, err.Error()) + return + } - download_name := request.DownloadFilename - if download_name == "" { - download_name = strings.Replace(log_path.Base(), "\"", "", -1) - } + org_manager, err := services.GetOrgManager() + if err != nil { + returnError(w, 404, err.Error()) + return + } - // Log an audit event. - user_record := GetUserInfo(r.Context(), org_config_obj) - principal := user_record.Name + org_config_obj, err := org_manager.GetOrgConfig(request.OrgId) + if err != nil { + returnError(w, 404, err.Error()) + return + } - // This should never happen! - if principal == "" { - returnError(w, 403, "Unauthenticated access.") - return - } + row_chan, closer, log_path, err := getRows( + r.Context(), org_config_obj, request) + if err != nil { + returnError(w, 400, "Invalid request") + return + } + defer closer() - permissions := acls.READ_RESULTS - perm, err := services.CheckAccess(org_config_obj, principal, permissions) - if !perm || err != nil { - returnError(w, 403, "Unauthenticated access.") - return - } + transform := getTransformer(r.Context(), org_config_obj, request) - opts := json.GetJsonOptsForTimezone(request.Timezone) - switch request.DownloadFormat { - case "csv": - download_name = strings.TrimSuffix(download_name, ".json") - download_name += ".csv" + download_name := request.DownloadFilename + if download_name == "" { + download_name = strings.Replace(log_path.Base(), "\"", "", -1) + } - // From here on we already sent the headers and we can - // not really report an error to the client. - w.Header().Set("Content-Disposition", "attachment; "+ - sanitizeFilenameForAttachment(download_name)) - w.Header().Set("Content-Type", "binary/octet-stream") - w.WriteHeader(200) + // Log an audit event. + user_record := GetUserInfo(r.Context(), org_config_obj) + principal := user_record.Name - services.LogAudit(r.Context(), - org_config_obj, principal, "DownloadTable", - ordereddict.NewDict(). - Set("request", request). - Set("remote", r.RemoteAddr)) - - scope := vql_subsystem.MakeScope() - csv_writer := csv.GetCSVAppender( - org_config_obj, scope, w, - csv.WriteHeaders, opts) - for row := range row_chan { - csv_writer.Write( - filterColumns(request.Columns, transform(row))) + // This should never happen! + if principal == "" { + returnError(w, 403, "Unauthenticated access.") + return } - csv_writer.Close() - // Output in jsonl by default. - default: - if !strings.HasSuffix(download_name, ".json") { - download_name += ".json" + permissions := acls.READ_RESULTS + perm, err := services.CheckAccess(org_config_obj, principal, permissions) + if !perm || err != nil { + returnError(w, 403, "Unauthenticated access.") + return } - // From here on we already sent the headers and we can - // not really report an error to the client. - w.Header().Set("Content-Disposition", "attachment; "+ - sanitizeFilenameForAttachment(download_name)) - w.Header().Set("Content-Type", "binary/octet-stream") - w.WriteHeader(200) + opts := json.GetJsonOptsForTimezone(request.Timezone) + switch request.DownloadFormat { + case "csv": + download_name = strings.TrimSuffix(download_name, ".json") + download_name += ".csv" - services.LogAudit(r.Context(), - org_config_obj, principal, "DownloadTable", - ordereddict.NewDict(). - Set("request", request). - Set("remote", r.RemoteAddr)) + // From here on we already sent the headers and we can + // not really report an error to the client. + w.Header().Set("Content-Disposition", "attachment; "+ + sanitizeFilenameForAttachment(download_name)) + w.Header().Set("Content-Type", "binary/octet-stream") + w.WriteHeader(200) - for row := range row_chan { - serialized, err := json.MarshalWithOptions( - filterColumns(request.Columns, transform(row)), - json.GetJsonOptsForTimezone(request.Timezone)) - if err != nil { - return + services.LogAudit(r.Context(), + org_config_obj, principal, "DownloadTable", + ordereddict.NewDict(). + Set("request", request). + Set("remote", r.RemoteAddr)) + + scope := vql_subsystem.MakeScope() + csv_writer := csv.GetCSVAppender( + org_config_obj, scope, w, + csv.WriteHeaders, opts) + for row := range row_chan { + csv_writer.Write( + filterColumns(request.Columns, transform(row))) + } + csv_writer.Close() + + // Output in jsonl by default. + default: + if !strings.HasSuffix(download_name, ".json") { + download_name += ".json" } - // Write line delimited JSON - _, _ = w.Write(serialized) - _, _ = w.Write([]byte{'\n'}) + // From here on we already sent the headers and we can + // not really report an error to the client. + w.Header().Set("Content-Disposition", "attachment; "+ + sanitizeFilenameForAttachment(download_name)) + w.Header().Set("Content-Type", "binary/octet-stream") + w.WriteHeader(200) + + services.LogAudit(r.Context(), + org_config_obj, principal, "DownloadTable", + ordereddict.NewDict(). + Set("request", request). + Set("remote", r.RemoteAddr)) + + for row := range row_chan { + serialized, err := json.MarshalWithOptions( + filterColumns(request.Columns, transform(row)), + json.GetJsonOptsForTimezone(request.Timezone)) + if err != nil { + return + } + + // Write line delimited JSON + _, _ = w.Write(serialized) + _, _ = w.Write([]byte{'\n'}) + } } - } - }) + }) } func vfsGetBuffer( diff --git a/api/fixtures/TestBasicAuthenticator.golden b/api/fixtures/TestBasicAuthenticator.golden new file mode 100644 index 00000000000..41407077493 --- /dev/null +++ b/api/fixtures/TestBasicAuthenticator.golden @@ -0,0 +1,84 @@ +{ + "Mux": { + "/favicon.png": [ + "*http.redirectHandler" + ], + "/velociraptor/": [ + "api.PrepareGUIMux" + ], + "/velociraptor/api/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " *utils.ServeMux" + ], + "/velociraptor/api/v1/DownloadTable": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " api.downloadTable" + ], + "/velociraptor/api/v1/DownloadVFSFile": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " api.vfsFileDownloadHandler" + ], + "/velociraptor/api/v1/UploadFormFile": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " api.formUploadHandler" + ], + "/velociraptor/api/v1/UploadTool": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " api.toolUploadHandler" + ], + "/velociraptor/app/": [ + "authenticators.IpFilter", + " utils.StripPrefix", + " NewInterceptingResponseWriter", + " *gzipped.fileHandler" + ], + "/velociraptor/app/index.html": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " api.GetTemplateHandler" + ], + "/velociraptor/app/logoff.html": [ + "authenticators.IpFilter", + " authenticators.(*BasicAuthenticator).AddLogoff" + ], + "/velociraptor/clients/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/downloads/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/hunts/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/notebooks/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.(*BasicAuthenticator).AuthenticateUserHandler", + " utils.StripPrefix", + " api.downloadFileStore" + ] + } +} \ No newline at end of file diff --git a/api/fixtures/TestMultiAuthenticator.golden b/api/fixtures/TestMultiAuthenticator.golden new file mode 100644 index 00000000000..1d2dbc5cef7 --- /dev/null +++ b/api/fixtures/TestMultiAuthenticator.golden @@ -0,0 +1,133 @@ +{ + "Redirect Provider *authenticators.OidcAuthenticator": "https://www.example.com/velociraptor/auth/oidc/callback", + "Redirect Provider *authenticators.GoogleAuthenticator": "https://www.example.com/velociraptor/auth/google/callback", + "Redirect Provider *authenticators.GitHubAuthenticator": "https://www.example.com/velociraptor/auth/github/callback", + "Redirect Provider *authenticators.OidcAuthenticatorCognito": "https://www.example.com/velociraptor/auth/oidc/cognito/callback", + "Redirect Provider *authenticators.AzureAuthenticator": "https://www.example.com/velociraptor/auth/azure/callback", + "Mux": { + "/favicon.png": [ + "*http.redirectHandler" + ], + "/velociraptor/": [ + "api.PrepareGUIMux" + ], + "/velociraptor/api/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " *utils.ServeMux" + ], + "/velociraptor/api/v1/DownloadTable": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " api.downloadTable" + ], + "/velociraptor/api/v1/DownloadVFSFile": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " api.vfsFileDownloadHandler" + ], + "/velociraptor/api/v1/UploadFormFile": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " api.formUploadHandler" + ], + "/velociraptor/api/v1/UploadTool": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " api.toolUploadHandler" + ], + "/velociraptor/app/": [ + "authenticators.IpFilter", + " utils.StripPrefix", + " NewInterceptingResponseWriter", + " *gzipped.fileHandler" + ], + "/velociraptor/app/index.html": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " api.GetTemplateHandler" + ], + "/velociraptor/app/logoff.html": [ + "authenticators.IpFilter", + " authenticators.installLogoff" + ], + "/velociraptor/auth/azure/callback": [ + "authenticators.IpFilter", + " authenticators.(*AzureAuthenticator).oauthAzureCallback" + ], + "/velociraptor/auth/azure/login": [ + "authenticators.IpFilter", + " authenticators.(*AzureAuthenticator).oauthAzureLogin" + ], + "/velociraptor/auth/azure/picture": [ + "authenticators.IpFilter", + " authenticators.(*AzureAuthenticator).oauthAzurePicture" + ], + "/velociraptor/auth/github/callback": [ + "authenticators.IpFilter", + " authenticators.(*GitHubAuthenticator).oauthGithubCallback" + ], + "/velociraptor/auth/github/login": [ + "authenticators.IpFilter", + " authenticators.(*GitHubAuthenticator).oauthGithubLogin" + ], + "/velociraptor/auth/google/callback": [ + "authenticators.IpFilter", + " authenticators.(*GoogleAuthenticator).oauthGoogleCallback" + ], + "/velociraptor/auth/google/login": [ + "authenticators.IpFilter", + " authenticators.(*GoogleAuthenticator).oauthGoogleLogin" + ], + "/velociraptor/auth/oidc/callback": [ + "authenticators.IpFilter", + " authenticators.(*OidcAuthenticator).oauthOidcCallback" + ], + "/velociraptor/auth/oidc/cognito/callback": [ + "authenticators.IpFilter", + " authenticators.(*OidcAuthenticatorCognito).oauthOidcCallback" + ], + "/velociraptor/auth/oidc/cognito/login": [ + "authenticators.IpFilter", + " authenticators.(*OidcAuthenticator).oauthOidcLogin" + ], + "/velociraptor/auth/oidc/login": [ + "authenticators.IpFilter", + " authenticators.(*OidcAuthenticator).oauthOidcLogin" + ], + "/velociraptor/clients/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/downloads/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/hunts/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " utils.StripPrefix", + " api.downloadFileStore" + ], + "/velociraptor/notebooks/": [ + "authenticators.IpFilter", + " api.csrfProtect", + " authenticators.authenticateUserHandle", + " utils.StripPrefix", + " api.downloadFileStore" + ] + } +} \ No newline at end of file diff --git a/api/proxy.go b/api/proxy.go index 2001ab33e93..bc78c705efc 100644 --- a/api/proxy.go +++ b/api/proxy.go @@ -41,12 +41,11 @@ import ( crypto_utils "www.velocidex.com/golang/velociraptor/crypto/utils" "www.velocidex.com/golang/velociraptor/grpc_client" "www.velocidex.com/golang/velociraptor/logging" - "www.velocidex.com/golang/velociraptor/server" "www.velocidex.com/golang/velociraptor/utils" ) // A Mux for the reverse proxy feature. -func AddProxyMux(config_obj *config_proto.Config, mux *http.ServeMux) error { +func AddProxyMux(config_obj *config_proto.Config, mux *api_utils.ServeMux) error { if config_obj.GUI == nil { return errors.New("GUI not configured") } @@ -64,28 +63,29 @@ func AddProxyMux(config_obj *config_proto.Config, mux *http.ServeMux) error { var handler http.Handler if target.Scheme == "file" { - handler = http.StripPrefix(reverse_proxy_config.Route, + handler = api_utils.StripPrefix(reverse_proxy_config.Route, http.FileServer(http.Dir(target.Path))) } else { - handler = http.StripPrefix(reverse_proxy_config.Route, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Host = target.Host - r.URL.Scheme = target.Scheme - r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) - r.Host = target.Host - - // If we require auth we do - // not pass the auth header to - // the target of the - // proxy. Otherwise we leave - // authentication to it. - if reverse_proxy_config.RequireAuth { - r.Header.Del("Authorization") - } - - httputil.NewSingleHostReverseProxy(target).ServeHTTP(w, r) - })) + handler = api_utils.StripPrefix(reverse_proxy_config.Route, + api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + r.URL.Host = target.Host + r.URL.Scheme = target.Scheme + r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) + r.Host = target.Host + + // If we require auth we do + // not pass the auth header to + // the target of the + // proxy. Otherwise we leave + // authentication to it. + if reverse_proxy_config.RequireAuth { + r.Header.Del("Authorization") + } + + httputil.NewSingleHostReverseProxy(target).ServeHTTP(w, r) + })) } if reverse_proxy_config.RequireAuth { @@ -106,8 +106,7 @@ func AddProxyMux(config_obj *config_proto.Config, mux *http.ServeMux) error { func PrepareGUIMux( ctx context.Context, config_obj *config_proto.Config, - server_obj *server.Server, - mux *http.ServeMux) (http.Handler, error) { + mux *api_utils.ServeMux) (http.Handler, error) { if config_obj.GUI == nil { return nil, errors.New("GUI not configured") } @@ -127,7 +126,8 @@ func PrepareGUIMux( return nil, err } if config_obj.GUI != nil && config_obj.GUI.Authenticator != nil { - server_obj.Info("GUI will use the %v authenticator", config_obj.GUI.Authenticator.Type) + logger := logging.GetLogger(config_obj, &logging.GUIComponent) + logger.Info("GUI will use the %v authenticator", config_obj.GUI.Authenticator.Type) } // Add the authenticator specific handlers. @@ -142,53 +142,53 @@ func PrepareGUIMux( return nil, err } - base := api_utils.GetBasePath(config_obj) - mux.Handle(api_utils.Join(base, "/api/"), ipFilter(config_obj, - csrfProtect(config_obj, - auther.AuthenticateUserHandler(h)))) + mux.Handle(api_utils.GetBasePath(config_obj, "/api/"), + ipFilter(config_obj, + csrfProtect(config_obj, + auther.AuthenticateUserHandler(h)))) - mux.Handle(api_utils.Join(base, "/api/v1/DownloadTable"), + mux.Handle(api_utils.GetBasePath(config_obj, "/api/v1/DownloadTable"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler(downloadTable())))) - mux.Handle(api_utils.Join(base, "/api/v1/DownloadVFSFile"), + mux.Handle(api_utils.GetBasePath(config_obj, "/api/v1/DownloadVFSFile"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler(vfsFileDownloadHandler())))) - mux.Handle(api_utils.Join(base, "/api/v1/UploadTool"), + mux.Handle(api_utils.GetBasePath(config_obj, "/api/v1/UploadTool"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler(toolUploadHandler())))) - mux.Handle(api_utils.Join(base, "/api/v1/UploadFormFile"), + mux.Handle(api_utils.GetBasePath(config_obj, "/api/v1/UploadFormFile"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler(formUploadHandler())))) // Serve prepared zip files. - mux.Handle(api_utils.Join(base, "/downloads/"), + mux.Handle(api_utils.GetBasePath(config_obj, "/downloads/"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler( - http.StripPrefix(base, + api_utils.StripPrefix(api_utils.GetBasePath(config_obj), downloadFileStore([]string{"downloads"})))))) // Serve notebook items - mux.Handle(api_utils.Join(base, "/notebooks/"), + mux.Handle(api_utils.GetBasePath(config_obj, "/notebooks/"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler( - http.StripPrefix(base, + api_utils.StripPrefix(api_utils.GetBasePath(config_obj), downloadFileStore([]string{"notebooks"})))))) // Serve files from hunt notebooks - mux.Handle(api_utils.Join(base, "/hunts/"), + mux.Handle(api_utils.GetBasePath(config_obj, "/hunts/"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler( - http.StripPrefix(base, + api_utils.StripPrefix(api_utils.GetBasePath(config_obj), downloadFileStore([]string{"hunts"})))))) // Serve files from client notebooks - mux.Handle(api_utils.Join(base, "/clients/"), + mux.Handle(api_utils.GetBasePath(config_obj, "/clients/"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler( - http.StripPrefix(base, + api_utils.StripPrefix(api_utils.GetBasePath(config_obj), downloadFileStore([]string{"clients"})))))) // Assets etc do not need auth. @@ -204,15 +204,18 @@ func PrepareGUIMux( if err != nil { return nil, err } - mux.Handle(api_utils.Join(base, "/app/index.html"), + mux.Handle(api_utils.GetBasePath(config_obj, "/app/index.html"), ipFilter(config_obj, csrfProtect(config_obj, auther.AuthenticateUserHandler(h)))) // Redirect everything else to the app mux.Handle(api_utils.GetBaseDirectory(config_obj), - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, api_utils.Join(base, "/app/index.html"), 302) - })) + api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, + api_utils.GetBasePath(config_obj, "/app/index.html"), + http.StatusTemporaryRedirect) + })) return mux, nil } @@ -310,10 +313,10 @@ func GetAPIHandler( return nil, err } - base := api_utils.GetBasePath(config_obj) - reverse_proxy_mux := http.NewServeMux() - reverse_proxy_mux.Handle(api_utils.Join(base, "/api/v1/"), - http.StripPrefix(base, grpc_proxy_mux)) + reverse_proxy_mux := api_utils.NewServeMux() + reverse_proxy_mux.Handle(api_utils.GetBasePath(config_obj, "/api/v1/"), + api_utils.StripPrefix( + api_utils.GetBasePath(config_obj), grpc_proxy_mux)) return reverse_proxy_mux, nil } diff --git a/api/proxy_test.go b/api/proxy_test.go new file mode 100644 index 00000000000..758fb9d4253 --- /dev/null +++ b/api/proxy_test.go @@ -0,0 +1,107 @@ +package api + +import ( + "fmt" + "testing" + + "github.com/Velocidex/ordereddict" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" + "www.velocidex.com/golang/velociraptor/api/authenticators" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" + config_proto "www.velocidex.com/golang/velociraptor/config/proto" + "www.velocidex.com/golang/velociraptor/file_store/test_utils" + "www.velocidex.com/golang/velociraptor/json" + "www.velocidex.com/golang/velociraptor/vtesting/assert" + "www.velocidex.com/golang/velociraptor/vtesting/goldie" +) + +type APIProxyTestSuite struct { + test_utils.TestSuite +} + +func (self *APIProxyTestSuite) TestMultiAuthenticator() { + mux := api_utils.NewServeMux() + + config_obj := proto.Clone(self.ConfigObj).(*config_proto.Config) + config_obj.GUI.PublicUrl = "https://www.example.com/" + config_obj.GUI.BasePath = "/velociraptor" + config_obj.GUI.Authenticator = &config_proto.Authenticator{ + Type: "multi", + SubAuthenticators: []*config_proto.Authenticator{{ + Type: "oidc", + OidcIssuer: "https://accounts.google.com", + OauthClientId: "CCCCC", + OauthClientSecret: "secret", + }, { + Type: "Google", + OauthClientId: "CCCCC", + OauthClientSecret: "secret", + }, { + Type: "GitHub", + OauthClientId: "CCCCC", + OauthClientSecret: "secret", + }, { + Type: "oidc-cognito", + OidcIssuer: "https://accounts.google.com", + OauthClientId: "CCCCC", + OauthClientSecret: "secret", + OidcName: "cognito", + }, { + Type: "azure", + OauthClientId: "CCCCC", + OauthClientSecret: "secret", + }}, + } + + _, err := PrepareGUIMux(self.Ctx, config_obj, mux) + assert.NoError(self.T(), err) + + auther, err := authenticators.NewAuthenticator(config_obj) + assert.NoError(self.T(), err) + + auther_multi, ok := auther.(*authenticators.MultiAuthenticator) + assert.True(self.T(), ok) + + golden := ordereddict.NewDict() + + for _, delegate := range auther_multi.Delegates() { + auther_oidc, ok := delegate.(authenticators.OIDCConnector) + if !ok { + continue + } + + oidc_config, err := auther_oidc.GetGenOauthConfig() + assert.NoError(self.T(), err) + golden.Set(fmt.Sprintf("Redirect Provider %T", delegate), + oidc_config.RedirectURL) + } + + golden.Set("Mux", mux.Debug()) + + goldie.Assert(self.T(), "TestMultiAuthenticator", json.MustMarshalIndent(golden)) +} + +func (self *APIProxyTestSuite) TestBasicAuthenticator() { + mux := api_utils.NewServeMux() + + config_obj := proto.Clone(self.ConfigObj).(*config_proto.Config) + config_obj.GUI.PublicUrl = "https://www.example.com/" + config_obj.GUI.BasePath = "/velociraptor" + config_obj.GUI.Authenticator = &config_proto.Authenticator{ + Type: "basic", + } + + _, err := PrepareGUIMux(self.Ctx, config_obj, mux) + assert.NoError(self.T(), err) + + golden := ordereddict.NewDict() + + golden.Set("Mux", mux.Debug()) + + goldie.Assert(self.T(), "TestBasicAuthenticator", json.MustMarshalIndent(golden)) +} + +func TestAPIProxy(t *testing.T) { + suite.Run(t, &APIProxyTestSuite{}) +} diff --git a/api/upload.go b/api/upload.go index d9a6c8e8942..91ab4556a9a 100644 --- a/api/upload.go +++ b/api/upload.go @@ -11,6 +11,7 @@ import ( "www.velocidex.com/golang/velociraptor/acls" "www.velocidex.com/golang/velociraptor/api/authenticators" api_proto "www.velocidex.com/golang/velociraptor/api/proto" + api_utils "www.velocidex.com/golang/velociraptor/api/utils" artifacts_proto "www.velocidex.com/golang/velociraptor/artifacts/proto" "www.velocidex.com/golang/velociraptor/json" "www.velocidex.com/golang/velociraptor/logging" @@ -19,230 +20,232 @@ import ( ) func toolUploadHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - org_id := authenticators.GetOrgIdFromRequest(r) - org_manager, err := services.GetOrgManager() - if err != nil { - returnError(w, http.StatusUnauthorized, err.Error()) - return - } - - org_config_obj, err := org_manager.GetOrgConfig(org_id) - if err != nil { - returnError(w, http.StatusUnauthorized, err.Error()) - return - } - - // Check for acls - userinfo := GetUserInfo(r.Context(), org_config_obj) - permissions := acls.ARTIFACT_WRITER - perm, err := services.CheckAccess(org_config_obj, userinfo.Name, permissions) - if !perm || err != nil { - returnError(w, http.StatusUnauthorized, - "User is not allowed to upload tools.") - return - } - - // Parse our multipart form, 10 << 20 specifies a maximum - // upload of 10 MB files. - err = r.ParseMultipartForm(10 << 25) - if err != nil { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - defer r.MultipartForm.RemoveAll() - - tool := &artifacts_proto.Tool{} - params, pres := r.Form["_params_"] - if !pres || len(params) != 1 { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - - err = json.Unmarshal([]byte(params[0]), tool) - if err != nil { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - - // FormFile returns the first file for the given key `myFile` - // it also returns the FileHeader so we can get the Filename, - // the Header and the size of the file - file, handler, err := r.FormFile("file") - if err != nil { - returnError(w, 403, fmt.Sprintf("Unsupported params: %v", err)) - return - } - defer file.Close() - - tool.Filename = path.Base(handler.Filename) - tool.ServeLocally = true - - path_manager := paths.NewInventoryPathManager(org_config_obj, tool) - pathspec, file_store_factory, err := path_manager.Path() - if err != nil { - returnError(w, 404, err.Error()) - } - - writer, err := file_store_factory.WriteFile(pathspec) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - defer writer.Close() - - err = writer.Truncate() - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - sha_sum := sha256.New() - - _, err = io.Copy(writer, io.TeeReader(file, sha_sum)) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - tool.Hash = hex.EncodeToString(sha_sum.Sum(nil)) - - inventory, err := services.GetInventory(org_config_obj) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - ctx := r.Context() - err = inventory.AddTool(ctx, org_config_obj, tool, - services.ToolOptions{ - AdminOverride: true, - }) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - // Now materialize the tool - tool, err = inventory.GetToolInfo( - r.Context(), org_config_obj, tool.Name, tool.Version) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - serialized, _ := json.Marshal(tool) - _, err = w.Write(serialized) - if err != nil { - logger := logging.GetLogger(org_config_obj, &logging.GUIComponent) - logger.Error("toolUploadHandler: %v", err) - } - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + org_id := authenticators.GetOrgIdFromRequest(r) + org_manager, err := services.GetOrgManager() + if err != nil { + returnError(w, http.StatusUnauthorized, err.Error()) + return + } + + org_config_obj, err := org_manager.GetOrgConfig(org_id) + if err != nil { + returnError(w, http.StatusUnauthorized, err.Error()) + return + } + + // Check for acls + userinfo := GetUserInfo(r.Context(), org_config_obj) + permissions := acls.ARTIFACT_WRITER + perm, err := services.CheckAccess(org_config_obj, userinfo.Name, permissions) + if !perm || err != nil { + returnError(w, http.StatusUnauthorized, + "User is not allowed to upload tools.") + return + } + + // Parse our multipart form, 10 << 20 specifies a maximum + // upload of 10 MB files. + err = r.ParseMultipartForm(10 << 25) + if err != nil { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + defer r.MultipartForm.RemoveAll() + + tool := &artifacts_proto.Tool{} + params, pres := r.Form["_params_"] + if !pres || len(params) != 1 { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + + err = json.Unmarshal([]byte(params[0]), tool) + if err != nil { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + + // FormFile returns the first file for the given key `myFile` + // it also returns the FileHeader so we can get the Filename, + // the Header and the size of the file + file, handler, err := r.FormFile("file") + if err != nil { + returnError(w, 403, fmt.Sprintf("Unsupported params: %v", err)) + return + } + defer file.Close() + + tool.Filename = path.Base(handler.Filename) + tool.ServeLocally = true + + path_manager := paths.NewInventoryPathManager(org_config_obj, tool) + pathspec, file_store_factory, err := path_manager.Path() + if err != nil { + returnError(w, 404, err.Error()) + } + + writer, err := file_store_factory.WriteFile(pathspec) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + defer writer.Close() + + err = writer.Truncate() + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + sha_sum := sha256.New() + + _, err = io.Copy(writer, io.TeeReader(file, sha_sum)) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + tool.Hash = hex.EncodeToString(sha_sum.Sum(nil)) + + inventory, err := services.GetInventory(org_config_obj) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + ctx := r.Context() + err = inventory.AddTool(ctx, org_config_obj, tool, + services.ToolOptions{ + AdminOverride: true, + }) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + // Now materialize the tool + tool, err = inventory.GetToolInfo( + r.Context(), org_config_obj, tool.Name, tool.Version) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + serialized, _ := json.Marshal(tool) + _, err = w.Write(serialized) + if err != nil { + logger := logging.GetLogger(org_config_obj, &logging.GUIComponent) + logger.Error("toolUploadHandler: %v", err) + } + }) } func formUploadHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - org_id := authenticators.GetOrgIdFromRequest(r) - org_manager, err := services.GetOrgManager() - if err != nil { - returnError(w, http.StatusUnauthorized, err.Error()) - return - } - - org_config_obj, err := org_manager.GetOrgConfig(org_id) - if err != nil { - returnError(w, http.StatusUnauthorized, err.Error()) - return - } - - // Check for acls - userinfo := GetUserInfo(r.Context(), org_config_obj) - permissions := acls.COLLECT_CLIENT - perm, err := services.CheckAccess(org_config_obj, userinfo.Name, permissions) - if !perm || err != nil { - returnError(w, http.StatusUnauthorized, - "User is not allowed to upload files for forms.") - return - } - - // Parse our multipart form, 10 << 20 specifies a maximum - // upload of 10 MB files. - err = r.ParseMultipartForm(10 << 20) - if err != nil { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - defer r.MultipartForm.RemoveAll() - - form_desc := &api_proto.FormUploadMetadata{} - params, pres := r.Form["_params_"] - if !pres || len(params) != 1 { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - - err = json.Unmarshal([]byte(params[0]), form_desc) - if err != nil { - returnError(w, http.StatusBadRequest, "Unsupported params") - return - } - - // FormFile returns the first file for the given key `file` - // it also returns the FileHeader so we can get the Filename, - // the Header and the size of the file - file, handler, err := r.FormFile("file") - if err != nil { - returnError(w, 403, fmt.Sprintf("Unsupported params: %v", err)) - return - } - defer file.Close() - - form_desc.Filename = path.Base(handler.Filename) - - path_manager := paths.NewFormUploadPathManager( - org_config_obj, form_desc.Filename) - - pathspec, file_store_factory, err := path_manager.Path() - if err != nil { - returnError(w, 403, fmt.Sprintf("Error: %v", err)) - return - } - - form_desc.Url = path_manager.URL() - - writer, err := file_store_factory.WriteFile(pathspec) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - defer writer.Close() - - err = writer.Truncate() - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - _, err = io.Copy(writer, file) - if err != nil { - returnError(w, http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err)) - return - } - - serialized, _ := json.Marshal(form_desc) - _, err = w.Write(serialized) - if err != nil { - logger := logging.GetLogger(org_config_obj, &logging.GUIComponent) - logger.Error("toolUploadHandler: %v", err) - } - }) + return api_utils.HandlerFunc(nil, + func(w http.ResponseWriter, r *http.Request) { + org_id := authenticators.GetOrgIdFromRequest(r) + org_manager, err := services.GetOrgManager() + if err != nil { + returnError(w, http.StatusUnauthorized, err.Error()) + return + } + + org_config_obj, err := org_manager.GetOrgConfig(org_id) + if err != nil { + returnError(w, http.StatusUnauthorized, err.Error()) + return + } + + // Check for acls + userinfo := GetUserInfo(r.Context(), org_config_obj) + permissions := acls.COLLECT_CLIENT + perm, err := services.CheckAccess(org_config_obj, userinfo.Name, permissions) + if !perm || err != nil { + returnError(w, http.StatusUnauthorized, + "User is not allowed to upload files for forms.") + return + } + + // Parse our multipart form, 10 << 20 specifies a maximum + // upload of 10 MB files. + err = r.ParseMultipartForm(10 << 20) + if err != nil { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + defer r.MultipartForm.RemoveAll() + + form_desc := &api_proto.FormUploadMetadata{} + params, pres := r.Form["_params_"] + if !pres || len(params) != 1 { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + + err = json.Unmarshal([]byte(params[0]), form_desc) + if err != nil { + returnError(w, http.StatusBadRequest, "Unsupported params") + return + } + + // FormFile returns the first file for the given key `file` + // it also returns the FileHeader so we can get the Filename, + // the Header and the size of the file + file, handler, err := r.FormFile("file") + if err != nil { + returnError(w, 403, fmt.Sprintf("Unsupported params: %v", err)) + return + } + defer file.Close() + + form_desc.Filename = path.Base(handler.Filename) + + path_manager := paths.NewFormUploadPathManager( + org_config_obj, form_desc.Filename) + + pathspec, file_store_factory, err := path_manager.Path() + if err != nil { + returnError(w, 403, fmt.Sprintf("Error: %v", err)) + return + } + + form_desc.Url = path_manager.URL() + + writer, err := file_store_factory.WriteFile(pathspec) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + defer writer.Close() + + err = writer.Truncate() + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + _, err = io.Copy(writer, file) + if err != nil { + returnError(w, http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err)) + return + } + + serialized, _ := json.Marshal(form_desc) + _, err = w.Write(serialized) + if err != nil { + logger := logging.GetLogger(org_config_obj, &logging.GUIComponent) + logger.Error("toolUploadHandler: %v", err) + } + }) } diff --git a/api/utils/mux.go b/api/utils/mux.go new file mode 100644 index 00000000000..c153d8d0332 --- /dev/null +++ b/api/utils/mux.go @@ -0,0 +1,114 @@ +package utils + +import ( + "fmt" + "net/http" + "path/filepath" + "runtime" + "sort" + "strings" + + "github.com/Velocidex/ordereddict" +) + +type Stringer interface { + String() string +} + +type ServeMux struct { + *http.ServeMux + + Handlers map[string]http.Handler +} + +func (self *ServeMux) Handle(pattern string, handler http.Handler) { + self.Handlers[pattern] = handler + self.ServeMux.Handle(pattern, handler) +} + +func (self *ServeMux) Debug() *ordereddict.Dict { + res := ordereddict.NewDict() + var keys []string + for k := range self.Handlers { + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, k := range keys { + v, ok := self.Handlers[k] + if !ok { + continue + } + + name := fmt.Sprintf("%T", v) + stringer, ok := v.(Stringer) + if ok { + name = stringer.String() + } + + parts := strings.Split(name, ":") + res.Set(k, parts) + } + return res +} + +func NewServeMux() *ServeMux { + return &ServeMux{ + ServeMux: http.NewServeMux(), + Handlers: make(map[string]http.Handler), + } +} + +type HandlerFuncContainer struct { + http.HandlerFunc + callSite string + parent *HandlerFuncContainer +} + +func (self *HandlerFuncContainer) String() string { + res := self.callSite + if self.parent != nil { + res += ": " + self.parent.String() + } + + return res +} + +func (self *HandlerFuncContainer) AddChild(note string) *HandlerFuncContainer { + self.callSite = note + return self +} + +func HandlerFunc(parent http.Handler, f http.HandlerFunc) *HandlerFuncContainer { + res := &HandlerFuncContainer{ + HandlerFunc: http.HandlerFunc(f), + } + + if parent != nil { + parent_handler, ok := parent.(*HandlerFuncContainer) + if ok { + res.parent = parent_handler + } else { + res.parent = &HandlerFuncContainer{ + callSite: fmt.Sprintf("%T", parent), + } + } + } + + pc, _, _, ok := runtime.Caller(1) + if ok { + details := runtime.FuncForPC(pc) + if details != nil { + res.callSite = filepath.Base(details.Name()) + } + } + + return res +} + +func StripPrefix(prefix string, h http.Handler) http.Handler { + handler := http.StripPrefix(prefix, h) + + return HandlerFunc(h, handler.ServeHTTP) +} diff --git a/api/utils/utils.go b/api/utils/utils.go index 5ad4f5639a9..d8a81f07410 100644 --- a/api/utils/utils.go +++ b/api/utils/utils.go @@ -4,37 +4,50 @@ import ( "strings" config_proto "www.velocidex.com/golang/velociraptor/config/proto" + "www.velocidex.com/golang/velociraptor/utils" ) // Normalize the base path. If base path is not specified or / return // "". Otherwise ensure base path has a leading / and no following / -func GetBasePath(config_obj *config_proto.Config) string { - if config_obj.GUI == nil || config_obj.GUI.BasePath == "" { - return "" - } +func GetBasePath(config_obj *config_proto.Config, parts ...string) string { + base, _ := utils.GetBaseURL(config_obj) - bare := strings.TrimSuffix(config_obj.GUI.BasePath, "/") - bare = strings.TrimPrefix(bare, "/") - if bare == "" { - return "" - } - return "/" + bare + args := append([]string{base.Path}, parts...) + base.Path = Join(args...) + return base.Path } // Return the base directory (with the trailing /) for the base path func GetBaseDirectory(config_obj *config_proto.Config) string { - return GetBasePath(config_obj) + "/" + base := GetBasePath(config_obj) + return strings.TrimSuffix(base, "/") + "/" } -// Ensure public URL start and ends with / -func GetPublicURL(config_obj *config_proto.Config) string { - bare := strings.TrimSuffix(config_obj.GUI.PublicUrl, "/") - return bare + "/" +// Returns the fully qualified URL to the API endpoint. +func GetPublicURL(config_obj *config_proto.Config, parts ...string) string { + base, err := utils.GetBaseURL(config_obj) + if err != nil { + return "" + } + args := append([]string{base.Path}, parts...) + base.Path = Join(args...) + return base.String() +} + +// Returns the absolute public URL referring to all the parts +func PublicURL(config_obj *config_proto.Config, parts ...string) string { + base, err := utils.GetBaseURL(config_obj) + if err != nil { + return "/" + } + args := append([]string{base.Path}, parts...) + base.Path = Join(args...) + return base.String() } // Join all parts of the URL to make sure that there is only a single // / between them regardless of if they have leading or trailing /. -// Ensure the url starts withv / unless it is an absolute URL starting +// Ensure the url starts with / unless it is an absolute URL starting // with http If the final part ends with / preserve that to refer to a // directory. func Join(parts ...string) string { diff --git a/bin/config_interactive.go b/bin/config_interactive.go index 891234e3324..b36bce3eab1 100644 --- a/bin/config_interactive.go +++ b/bin/config_interactive.go @@ -18,6 +18,7 @@ import ( config_proto "www.velocidex.com/golang/velociraptor/config/proto" logging "www.velocidex.com/golang/velociraptor/logging" "www.velocidex.com/golang/velociraptor/services/users" + "www.velocidex.com/golang/velociraptor/utils" "www.velocidex.com/golang/velociraptor/utils/tempfile" ) @@ -377,19 +378,22 @@ func configureSSO(config_obj *config_proto.Config) error { } // Provide the user with a hint about the redirect URL - redirect := "" + redirect, err := utils.GetBaseURL(config_obj) + if err != nil { + return err + } switch config_obj.GUI.Authenticator.Type { case "Google": - redirect = config_obj.GUI.PublicUrl + "auth/google/callback" + redirect.Path = path.Join(redirect.Path, "auth/google/callback") case "GitHub": - redirect = config_obj.GUI.PublicUrl + "auth/github/callback" + redirect.Path = path.Join(redirect.Path, "auth/github/callback") case "Azure": - redirect = config_obj.GUI.PublicUrl + "auth/azure/callback" + redirect.Path = path.Join(redirect.Path, "auth/azure/callback") case "OIDC": - redirect = config_obj.GUI.PublicUrl + "auth/oidc/callback" + redirect.Path = path.Join(redirect.Path, "auth/oidc/callback") } fmt.Printf("\nSetting %v configuration will use redirect URL %v\n", - config_obj.GUI.Authenticator.Type, redirect) + config_obj.GUI.Authenticator.Type, redirect.String()) switch config_obj.GUI.Authenticator.Type { case "Google", "GitHub": diff --git a/file_store/test_utils/testsuite.go b/file_store/test_utils/testsuite.go index f5cf6812f28..2981ca9da2d 100644 --- a/file_store/test_utils/testsuite.go +++ b/file_store/test_utils/testsuite.go @@ -10,6 +10,10 @@ import ( "github.com/Velocidex/yaml/v2" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + actions_proto "www.velocidex.com/golang/velociraptor/actions/proto" + flows_proto "www.velocidex.com/golang/velociraptor/flows/proto" + "www.velocidex.com/golang/velociraptor/utils" + "www.velocidex.com/golang/velociraptor/vql/acl_managers" "www.velocidex.com/golang/velociraptor/vtesting/assert" artifacts_proto "www.velocidex.com/golang/velociraptor/artifacts/proto" @@ -110,6 +114,41 @@ type TestSuite struct { Services *orgs.ServiceContainer } +func (self *TestSuite) CreateClient(client_id string) { + client_info_manager, err := services.GetClientInfoManager(self.ConfigObj) + assert.NoError(self.T(), err) + + err = client_info_manager.Set(self.Ctx, &services.ClientInfo{ + actions_proto.ClientInfo{ + ClientId: client_id, + }}) + assert.NoError(self.T(), err) +} + +func (self *TestSuite) CreateFlow(client_id, flow_id string) { + defer utils.SetFlowIdForTests(flow_id)() + + launcher, err := services.GetLauncher(self.ConfigObj) + assert.NoError(self.T(), err) + + manager, err := services.GetRepositoryManager(self.ConfigObj) + assert.NoError(self.T(), err) + + repository, err := manager.GetGlobalRepository(self.ConfigObj) + require.NoError(self.T(), err) + + _, err = launcher.ScheduleArtifactCollection( + self.Ctx, + self.ConfigObj, + acl_managers.NullACLManager{}, + repository, + &flows_proto.ArtifactCollectorArgs{ + ClientId: client_id, + Artifacts: []string{"Generic.Client.Info"}, + }, nil) + assert.NoError(self.T(), err) +} + func (self *TestSuite) LoadConfig() *config_proto.Config { os.Setenv("VELOCIRAPTOR_CONFIG", SERVER_CONFIG) config_obj, err := new(config.Loader). diff --git a/gui/velociraptor/package-lock.json b/gui/velociraptor/package-lock.json index 679b025b702..effb5b255b3 100644 --- a/gui/velociraptor/package-lock.json +++ b/gui/velociraptor/package-lock.json @@ -15,7 +15,7 @@ "@fortawesome/free-solid-svg-icons": "^6.6.0", "@fortawesome/react-fontawesome": "0.2.2", "@popperjs/core": "^2.11.8", - "ace-builds": "1.36.1", + "ace-builds": "1.36.3", "axios": ">=1.7.5", "axios-retry": "3.9.1", "bootstrap": "5.3.3", @@ -55,7 +55,7 @@ "react-simple-snackbar": "^1.1.11", "react-split-pane": "^0.1.92", "react-step-wizard": "^5.3.11", - "recharts": "^2.12.7", + "recharts": "^2.13.2", "sprintf-js": "1.1.3", "url-parse": "^1.5.10", "webpack": "^5.95.0" @@ -2089,9 +2089,9 @@ "dev": true }, "node_modules/@babel/runtime": { - "version": "7.25.7", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.7.tgz", - "integrity": "sha512-FjoyLe754PMiYsFaN5C94ttGiOmBNYTf6pLr4xXHAT5uctHb092PBszndLDR5XA/jghQvn4n7JMHl7dmTgbm9w==", + "version": "7.26.0", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.26.0.tgz", + "integrity": "sha512-FDSOghenHTiToteC/QRlv2q3DhPZ/oOXTBoirfWNx1Cx3TMVcGWQtMMmQcSvb/JjpNeGzx8Pq/b4fKEJuWm1sw==", "license": "MIT", "dependencies": { "regenerator-runtime": "^0.14.0" @@ -8245,9 +8245,9 @@ } }, "node_modules/recharts": { - "version": "2.12.7", - "resolved": "https://registry.npmjs.org/recharts/-/recharts-2.12.7.tgz", - "integrity": "sha512-hlLJMhPQfv4/3NBSAyq3gzGg4h2v69RJh6KU7b3pXYNNAELs9kEoXOjbkxdXpALqKBoVmVptGfLpxdaVYqjmXQ==", + "version": "2.13.2", + "resolved": "https://registry.npmjs.org/recharts/-/recharts-2.13.2.tgz", + "integrity": "sha512-UDLGFmnsBluDIPpQb9uty0ejb+jiVI71vkki8vVsR6ZCJdgjBfKQoQfft4re99CKlTy9qjQApxCLG6TrxJkeAg==", "license": "MIT", "dependencies": { "clsx": "^2.0.0", @@ -10958,9 +10958,9 @@ "dev": true }, "@babel/runtime": { - "version": "7.25.7", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.7.tgz", - "integrity": "sha512-FjoyLe754PMiYsFaN5C94ttGiOmBNYTf6pLr4xXHAT5uctHb092PBszndLDR5XA/jghQvn4n7JMHl7dmTgbm9w==", + "version": "7.26.0", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.26.0.tgz", + "integrity": "sha512-FDSOghenHTiToteC/QRlv2q3DhPZ/oOXTBoirfWNx1Cx3TMVcGWQtMMmQcSvb/JjpNeGzx8Pq/b4fKEJuWm1sw==", "requires": { "regenerator-runtime": "^0.14.0" } @@ -15443,9 +15443,9 @@ } }, "recharts": { - "version": "2.12.7", - "resolved": "https://registry.npmjs.org/recharts/-/recharts-2.12.7.tgz", - "integrity": "sha512-hlLJMhPQfv4/3NBSAyq3gzGg4h2v69RJh6KU7b3pXYNNAELs9kEoXOjbkxdXpALqKBoVmVptGfLpxdaVYqjmXQ==", + "version": "2.13.2", + "resolved": "https://registry.npmjs.org/recharts/-/recharts-2.13.2.tgz", + "integrity": "sha512-UDLGFmnsBluDIPpQb9uty0ejb+jiVI71vkki8vVsR6ZCJdgjBfKQoQfft4re99CKlTy9qjQApxCLG6TrxJkeAg==", "requires": { "clsx": "^2.0.0", "eventemitter3": "^4.0.1", diff --git a/gui/velociraptor/package.json b/gui/velociraptor/package.json index 6eb4f3aa706..ccaa5fa8026 100644 --- a/gui/velociraptor/package.json +++ b/gui/velociraptor/package.json @@ -4,7 +4,7 @@ "private": true, "type": "module", "dependencies": { - "@babel/runtime": "^7.25.7", + "@babel/runtime": "^7.26.0", "@fortawesome/fontawesome-svg-core": "6.6.0", "@fortawesome/free-regular-svg-icons": "6.6.0", "@fortawesome/free-solid-svg-icons": "^6.6.0", @@ -50,7 +50,7 @@ "react-simple-snackbar": "^1.1.11", "react-split-pane": "^0.1.92", "react-step-wizard": "^5.3.11", - "recharts": "^2.12.7", + "recharts": "^2.13.2", "sprintf-js": "1.1.3", "url-parse": "^1.5.10", "webpack": "5.95.0" diff --git a/gui/velociraptor/src/components/core/api-service.jsx b/gui/velociraptor/src/components/core/api-service.jsx index 17b18239fbc..ca22a1f4789 100644 --- a/gui/velociraptor/src/components/core/api-service.jsx +++ b/gui/velociraptor/src/components/core/api-service.jsx @@ -268,7 +268,7 @@ const upload = function(url, files, params) { // * it either starts with base path or not - URLs that do not start // with the base path will be fixed later. const internal_links = new RegExp( - "^(" + base_path + ")?/(api|app|notebooks|downloads|hunts|clients)/"); + "^(" + base_path + ")?/(api|app|notebooks|downloads|hunts|clients|auth)/"); // Prepare a suitable href link for // This function accepts a number of options: diff --git a/gui/velociraptor/src/components/welcome/login.jsx b/gui/velociraptor/src/components/welcome/login.jsx index 2b8932519fb..a973dad5721 100644 --- a/gui/velociraptor/src/components/welcome/login.jsx +++ b/gui/velociraptor/src/components/welcome/login.jsx @@ -12,6 +12,7 @@ import github_logo from "./Github-octocat-icon-vector-01.svg"; import google_logo from "./Google-icon-vector-04.svg"; import azure_logo from "./Microsoft_Azure_Logo.svg"; import openid_logo from "./OpenID_logo.svg"; +import T from '../i8n/i8n.jsx'; class Authenticator extends Component { @@ -48,7 +49,7 @@ class Authenticator extends Component {