diff --git a/.gitignore b/.gitignore index e8df2da4..8e7ed7c1 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ __debug_* *.key .env **/*.local.* +**/*.pem diff --git a/.vscode/launch.json b/.vscode/launch.json index db081c57..49672fbd 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ // Comment out if you need to setup environment variables for the module "envFile": "${workspaceFolder}/.env", // "args": [ - // "--port=5570", + // "--BRUTE_FORCE_MAX_LOGIN_ATTEMPTS=10", // ] }, { diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b632571..0428d4c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,41 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.4.4] - 2024-01-12 ### Added -- Initial project setup +- brute force attack protection, this will lock accounts after x attempts, by + default 5 attempts and will use by default incremental wait periods for each + failed attempts, all of these parameters can be changed +- added the ability to sign a token with different algorithms, by default it will + use HS256, but you can change it to RS256, HS384, RS384, HS512, RS512, this will + cater for the request we had for asymmetric keys +- added a random secret generator for the default HS256 is none is provided, this + is a change from previous versions where we used the machine id as the secret + this will increment the security of the default installation +- added a password complexity pipeline for checking if the users passwords adhere + to the complexity requirements, this can be disabled if required, by default the + password complexity is enabled and the complexity is set to 12 characters, at least + one uppercase, one lowercase, one number and one special character +- added a diagnostics class to better cater for errors and exceptions, this will + allow us to better handle errors and exceptions and return a more meaningful + error message to the user a the moment is not used in all of the code, but we + will be adding it to all of the code in the future + +### Changed + +- added back the ability to hash passwords using the SHA256 algorithm, this was + removed in a previous version, but we have added it back as some users already + had passwords hashed using this algorithm and this was breaking them. the default + installation will use the bcrypt algorithm + +### Fixed + +- fixed an issue where the token validation endpoint was not working and only accepted + GET requests, it now accepts only POST requests as expected and documented -## [0.4.3] - 2024-01-12 +## [0.4.3] - 2024-01-09 ### Added @@ -21,3 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - fixed a bug where a host would not show it status correctly + +## [Unreleased] + +### Added + +- Initial project setup diff --git a/src/catalog/models/pull_catalog_manifest.go b/src/catalog/models/pull_catalog_manifest.go index 856e88b6..1ab4628c 100644 --- a/src/catalog/models/pull_catalog_manifest.go +++ b/src/catalog/models/pull_catalog_manifest.go @@ -49,7 +49,7 @@ func (r *PullCatalogManifestRequest) Validate() error { } ctx := basecontext.NewRootBaseContext() - svcCtl := system.Get(ctx) + svcCtl := system.Get() arch, err := svcCtl.GetArchitecture(ctx) if err != nil { return errors.New("unable to determine architecture") diff --git a/src/catalog/models/push_catalog_manifest.go b/src/catalog/models/push_catalog_manifest.go index 48839487..62a36888 100644 --- a/src/catalog/models/push_catalog_manifest.go +++ b/src/catalog/models/push_catalog_manifest.go @@ -61,7 +61,7 @@ func (r *PushCatalogManifestRequest) Validate() error { if r.Architecture == "" { ctx := basecontext.NewRootBaseContext() - sysCtl := system.Get(ctx) + sysCtl := system.Get() arch, err := sysCtl.GetArchitecture(ctx) if err != nil { return errors.NewWithCode("unable to determine architecture and none was set", 400) diff --git a/src/catalog/pull.go b/src/catalog/pull.go index d3355192..be3f8ab0 100644 --- a/src/catalog/pull.go +++ b/src/catalog/pull.go @@ -77,7 +77,7 @@ func (s *CatalogManifestService) Pull(ctx basecontext.ApiContext, r *models.Pull manifest = &models.VirtualMachineCatalogManifest{} manifest.Provider = &provider apiClient.SetAuthorization(GetAuthenticator(manifest.Provider)) - srvCtl := system.Get(ctx) + srvCtl := system.Get() arch, err := srvCtl.GetArchitecture(ctx) if err != nil { response.AddError(errors.New("unable to determine architecture")) diff --git a/src/cmd/api.go b/src/cmd/api.go index 6523bd0f..205c68ac 100644 --- a/src/cmd/api.go +++ b/src/cmd/api.go @@ -8,9 +8,9 @@ import ( "github.com/Parallels/pd-api-service/common" "github.com/Parallels/pd-api-service/config" "github.com/Parallels/pd-api-service/constants" - "github.com/Parallels/pd-api-service/helpers" "github.com/Parallels/pd-api-service/orchestrator" "github.com/Parallels/pd-api-service/restapi" + "github.com/Parallels/pd-api-service/security/password" "github.com/Parallels/pd-api-service/serviceprovider" "github.com/Parallels/pd-api-service/startup" "github.com/cjlapao/common-go/helper" @@ -29,8 +29,8 @@ func processApi(ctx basecontext.ApiContext) { if cfg.GetSecurityKey() == "" { common.Logger.Warn("No security key found, database will be unencrypted") } - startup.Start() + startup.Init() currentUser, err := serviceprovider.Get().System.GetCurrentUser(ctx) if err != nil { @@ -49,7 +49,8 @@ func processApi(ctx basecontext.ApiContext) { rootUser, _ := db.GetUser(ctx, "root") rootPassword := os.Getenv(constants.ROOT_PASSWORD_ENV_VAR) if rootUser != nil { - if err := helpers.BcryptCompare(rootPassword, rootUser.ID, rootUser.Password); err != nil { + passwdSvc := password.Get() + if err := passwdSvc.Compare(rootPassword, rootUser.ID, rootUser.Password); err != nil { ctx.LogInfo("Updating root password") if err := db.UpdateRootPassword(ctx, os.Getenv(constants.ROOT_PASSWORD_ENV_VAR)); err != nil { panic(err) diff --git a/src/cmd/catalog.go b/src/cmd/catalog.go index b16c8a87..381e3556 100644 --- a/src/cmd/catalog.go +++ b/src/cmd/catalog.go @@ -70,7 +70,7 @@ func catalogInitPdFile(ctx basecontext.ApiContext, cmd string) *pdfile.PDFile { } if pdFile.Owner == "" { - user, _ := system.Get(ctx).GetCurrentUser(ctx) + user, _ := system.Get().GetCurrentUser(ctx) if user != "" { pdFile.Owner = user } diff --git a/src/cmd/generate_security_key.go b/src/cmd/generate_security_key.go index 41dfc95a..d7cc0725 100644 --- a/src/cmd/generate_security_key.go +++ b/src/cmd/generate_security_key.go @@ -2,6 +2,7 @@ package cmd import ( "os" + "strconv" "github.com/Parallels/pd-api-service/basecontext" "github.com/Parallels/pd-api-service/constants" @@ -10,14 +11,24 @@ import ( ) func processGenerateSecurityKey(ctx basecontext.ApiContext) { - ctx.LogInfo("Generating security key") filename := "private.key" if helper.GetFlagValue(constants.FILE_FLAG, "") != "" { filename = helper.GetFlagValue(constants.FILE_FLAG, "") } + keySize := 2048 + if helper.GetFlagValue(constants.SIZE_FLAG, "") != "" { + size, err := strconv.Atoi(helper.GetFlagValue(constants.SIZE_FLAG, "0")) + if err != nil { + ctx.LogError("Error parsing size flag: %s", err.Error()) + } else { + keySize = size + } + } + + ctx.LogInfo("Generating security key, with size %v", keySize) - err := security.GenPrivateRsaKey(filename) + err := security.GenPrivateRsaKey(filename, keySize) if err != nil { panic(err) } diff --git a/src/cmd/main.go b/src/cmd/main.go index 6de19830..f9fe904a 100644 --- a/src/cmd/main.go +++ b/src/cmd/main.go @@ -34,5 +34,6 @@ func Process() { default: processApi(ctx) } + os.Exit(0) } diff --git a/src/config/main.go b/src/config/main.go index 65134dd5..c8abe005 100644 --- a/src/config/main.go +++ b/src/config/main.go @@ -152,7 +152,7 @@ func (c *Config) GetTokenDurationMinutes() int { func (c *Config) GetRootFolder() (string, error) { ctx := basecontext.NewRootBaseContext() - srv := system.Get(ctx) + srv := system.Get() currentUser, err := srv.GetCurrentUser(ctx) if err != nil { currentUser = "root" @@ -284,3 +284,40 @@ func (c *Config) UseOrchestratorResources() bool { return false } + +func (c *Config) GetKey(key string) string { + value := os.Getenv(key) + if value == "" { + value = helper.GetFlagValue(key, "") + } + + return value +} + +func (c *Config) GetIntKey(key string) int { + value := c.GetKey(key) + if value == "" { + return 0 + } + + intVal, err := strconv.Atoi(value) + if err != nil { + return 0 + } + + return intVal +} + +func (c *Config) GetBoolKey(key string) bool { + value := c.GetKey(key) + if value == "" { + return false + } + + boolVal, err := strconv.ParseBool(value) + if err != nil { + return false + } + + return boolVal +} diff --git a/src/constants/brute_force_guard.go b/src/constants/brute_force_guard.go new file mode 100644 index 00000000..25426475 --- /dev/null +++ b/src/constants/brute_force_guard.go @@ -0,0 +1,7 @@ +package constants + +const ( + BRUTE_FORCE_MAX_LOGIN_ATTEMPTS_ENV_VAR = "BRUTE_FORCE_MAX_LOGIN_ATTEMPTS" + BRUTE_FORCE_LOCKOUT_DURATION_ENV_VAR = "BRUTE_FORCE_LOCKOUT_DURATION" + BRUTE_FORCE_INCREMENTAL_WAIT_ENV_VAR = "BRUTE_FORCE_INCREMENTAL_WAIT" +) diff --git a/src/constants/jwt.go b/src/constants/jwt.go new file mode 100644 index 00000000..8cb6df19 --- /dev/null +++ b/src/constants/jwt.go @@ -0,0 +1,8 @@ +package constants + +const ( + JWT_PRIVATE_KEY_ENV_VAR = "JWT_PRIVATE_KEY" + JWT_HMACS_SECRET_ENV_VAR = "JWT_HMACS_SECRET" + JWT_DURATION_ENV_VAR = "JWT_DURATION" + JWT_SIGN_ALGORITHM_ENV_VAR = "JWT_SIGN_ALGORITHM" +) diff --git a/src/constants/main.go b/src/constants/main.go index ebe1bd05..51d22ea1 100644 --- a/src/constants/main.go +++ b/src/constants/main.go @@ -71,6 +71,7 @@ const ( API_PORT_FLAG = "port" UPDATE_ROOT_PASSWORD_FLAG = "update-root-pass" FILE_FLAG = "file" + SIZE_FLAG = "size" MODE_FLAG = "mode" HELP_FLAG = "help" PASSWORD_FLAG = "password" diff --git a/src/constants/password.go b/src/constants/password.go new file mode 100644 index 00000000..326f546d --- /dev/null +++ b/src/constants/password.go @@ -0,0 +1,11 @@ +package constants + +const ( + SECURITY_PASSWORD_MIN_PASSWORD_LENGTH_ENV_VAR = "SECURITY_PASSWORD_MIN_PASSWORD_LENGTH" + SECURITY_PASSWORD_MAX_PASSWORD_LENGTH_ENV_VAR = "SECURITY_PASSWORD_MAX_PASSWORD_LENGTH" + SECURITY_PASSWORD_REQUIRE_LOWERCASE_ENV_VAR = "SECURITY_PASSWORD_REQUIRE_LOWERCASE" + SECURITY_PASSWORD_REQUIRE_UPPERCASE_ENV_VAR = "SECURITY_PASSWORD_REQUIRE_UPPERCASE" + SECURITY_PASSWORD_REQUIRE_NUMBER_ENV_VAR = "SECURITY_PASSWORD_REQUIRE_NUMBER" + SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR_ENV_VAR = "SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR" + SECURITY_PASSWORD_SALT_PASSWORD_ENV_VAR = "SECURITY_PASSWORD_SALT_PASSWORD" +) diff --git a/src/controllers/authorization.go b/src/controllers/authorization.go index 01e66b84..ffd90c2e 100644 --- a/src/controllers/authorization.go +++ b/src/controllers/authorization.go @@ -3,18 +3,16 @@ package controllers import ( "encoding/json" "net/http" - "time" "github.com/Parallels/pd-api-service/basecontext" - "github.com/Parallels/pd-api-service/config" - "github.com/Parallels/pd-api-service/data" - "github.com/Parallels/pd-api-service/helpers" "github.com/Parallels/pd-api-service/models" "github.com/Parallels/pd-api-service/restapi" + bruteforceguard "github.com/Parallels/pd-api-service/security/brute_force_guard" + "github.com/Parallels/pd-api-service/security/jwt" + "github.com/Parallels/pd-api-service/security/password" "github.com/Parallels/pd-api-service/serviceprovider" "github.com/cjlapao/common-go/helper/http_helper" - "github.com/dgrijalva/jwt-go" ) func registerAuthorizationHandlers(ctx basecontext.ApiContext, version string) { @@ -27,7 +25,7 @@ func registerAuthorizationHandlers(ctx basecontext.ApiContext, version string) { Register() restapi.NewController(). - WithMethod(restapi.GET). + WithMethod(restapi.POST). WithVersion(version). WithPath("/auth/token/validate"). WithHandler(ValidateTokenHandler()). @@ -46,7 +44,7 @@ func registerAuthorizationHandlers(ctx basecontext.ApiContext, version string) { func GetTokenHandler() restapi.ControllerHandler { return func(w http.ResponseWriter, r *http.Request) { ctx := GetBaseContext(r) - cfg := config.NewConfig() + var request models.LoginRequest if err := http_helper.MapRequestBody(r, &request); err != nil { ReturnApiError(ctx, w, models.ApiErrorResponse{ @@ -70,61 +68,77 @@ func GetTokenHandler() restapi.ControllerHandler { user, err := dbService.GetUser(ctx, request.Email) if err != nil { - ReturnApiError(ctx, w, models.NewFromError(err)) + ReturnApiError(ctx, w, models.ApiErrorResponse{ + Message: "Invalid User or Password", + Code: http.StatusUnauthorized, + }) return } if user == nil { ReturnApiError(ctx, w, models.ApiErrorResponse{ - Message: data.ErrUserNotFound.Error(), + Message: "Invalid User or Password", Code: http.StatusUnauthorized, }) return } - if err := helpers.BcryptCompare(request.Password, user.ID, user.Password); err != nil { + bruteForceSvc := bruteforceguard.Get() + + passwdSvc := password.Get() + if err := passwdSvc.Compare(request.Password, user.ID, user.Password); err != nil { ReturnApiError(ctx, w, models.ApiErrorResponse{ - Message: "Invalid Password", + Message: "Invalid User or Password", Code: http.StatusUnauthorized, }) + + if diag := bruteForceSvc.Process(user.ID, false, "Invalid Password"); diag.HasErrors() { + ctx.LogError("Error processing brute force guard: %v", diag) + } return } - roles := make([]string, 0) - claims := make([]string, 0) - for _, role := range user.Roles { - roles = append(roles, role.Name) + userRoles := make([]string, 0) + userClaims := make([]string, 0) + for _, userRole := range user.Roles { + userRoles = append(userRoles, userRole.Name) } - for _, claim := range user.Claims { - claims = append(claims, claim.Name) + for _, userClaim := range user.Claims { + userClaims = append(userClaims, userClaim.Name) } - expiresAt := time.Now().Add(time.Minute * time.Duration(cfg.GetTokenDurationMinutes())).Unix() - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + claims := map[string]interface{}{ "email": request.Email, - "roles": roles, - "claims": claims, - "exp": expiresAt, - }) - - // We either signing the token with the HMAC secret or the secret from the config - var key []byte - if cfg.GetHmacSecret() == "" { - key = []byte(cfg.GetHmacSecret()) - } else { - key = []byte(serviceprovider.Get().HardwareSecret) + "uid": user.ID, + "roles": userRoles, + "claims": userClaims, } - - tokenString, err := token.SignedString(key) + tokenSvc := jwt.Get() + tokenStr, err := tokenSvc.Sign(claims) + if err != nil { + ReturnApiError(ctx, w, models.NewFromErrorWithCode(err, 401)) + if diag := bruteForceSvc.Process(user.ID, false, err.Error()); diag.HasErrors() { + ctx.LogError("Error processing brute force guard: %v", diag) + } + return + } + token, err := tokenSvc.Parse(tokenStr) if err != nil { ReturnApiError(ctx, w, models.NewFromErrorWithCode(err, 401)) + if diag := bruteForceSvc.Process(user.ID, false, err.Error()); diag.HasErrors() { + ctx.LogError("Error processing brute force guard: %v", diag) + } return } response := models.LoginResponse{ - Token: tokenString, + Token: tokenStr, Email: request.Email, - ExpiresAt: expiresAt, + ExpiresAt: int64(token.Claims["exp"].(float64)), + } + + if diag := bruteForceSvc.Process(user.ID, true, "Success"); diag.HasErrors() { + ctx.LogError("Error processing brute force guard: %v", diag) } w.WriteHeader(http.StatusOK) @@ -145,7 +159,7 @@ func GetTokenHandler() restapi.ControllerHandler { func ValidateTokenHandler() restapi.ControllerHandler { return func(w http.ResponseWriter, r *http.Request) { ctx := GetBaseContext(r) - cfg := config.NewConfig() + var request models.ValidateTokenRequest if err := http_helper.MapRequestBody(r, &request); err != nil { ReturnApiError(ctx, w, models.ApiErrorResponse{ @@ -161,37 +175,29 @@ func ValidateTokenHandler() restapi.ControllerHandler { return } - token, err := jwt.Parse(request.Token, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, jwt.ErrSignatureInvalid - } - - // We either signing the token with the HMAC secret or the secret from the config - var key []byte - if cfg.GetHmacSecret() == "" { - key = []byte(cfg.GetHmacSecret()) - } else { - key = []byte(serviceprovider.Get().HardwareSecret) - } - return key, nil - }) - + tokenSvc := jwt.Get() + token, err := tokenSvc.Parse(request.Token) if err != nil { ReturnApiError(ctx, w, models.NewFromErrorWithCode(err, 401)) return } - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(models.ValidateTokenResponse{ - Valid: true, - }) - ctx.LogInfo("Token for user %s is valid", claims["email"]) + isValid, err := token.Valid() + if err != nil { + ReturnApiError(ctx, w, models.NewFromErrorWithCode(err, 401)) return - } else { + } + + if !isValid { ReturnApiError(ctx, w, models.NewFromErrorWithCode(err, 401)) - ctx.LogError("Token is invalid") return } + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(models.ValidateTokenResponse{ + Valid: true, + }) + email, _ := token.GetEmail() + ctx.LogInfo("Token for user %s is valid", email) } } diff --git a/src/controllers/machines.go b/src/controllers/machines.go index ad81e5cd..c9ade755 100644 --- a/src/controllers/machines.go +++ b/src/controllers/machines.go @@ -894,7 +894,7 @@ func CreateVirtualMachineHandler() restapi.ControllerHandler { // Attempt to get the architecture from the system if request.Architecture == "" { - svcCtl := system.Get(ctx) + svcCtl := system.Get() arch, err := svcCtl.GetArchitecture(ctx) if err != nil { ReturnApiError(ctx, w, models.ApiErrorResponse{ diff --git a/src/controllers/users.go b/src/controllers/users.go index d06fb5c7..057d91dd 100644 --- a/src/controllers/users.go +++ b/src/controllers/users.go @@ -9,6 +9,7 @@ import ( "github.com/Parallels/pd-api-service/mappers" "github.com/Parallels/pd-api-service/models" "github.com/Parallels/pd-api-service/restapi" + "github.com/Parallels/pd-api-service/security/password" "github.com/Parallels/pd-api-service/serviceprovider" "github.com/cjlapao/common-go/helper/http_helper" @@ -212,6 +213,24 @@ func CreateUserHandler() restapi.ControllerHandler { }) return } + if request.Password != "" { + passwordSvc := password.Get() + if valid, diag := passwordSvc.CheckPasswordComplexity(request.Password); diag.HasErrors() { + ReturnApiError(ctx, w, models.ApiErrorResponse{ + Message: diag.Error(), + Code: http.StatusBadRequest, + }) + return + } else { + if !valid { + ReturnApiError(ctx, w, models.ApiErrorResponse{ + Message: "Invalid Password, please check complexity rules", + Code: http.StatusBadRequest, + }) + return + } + } + } dbService, err := serviceprovider.GetDatabaseService(ctx) if err != nil { diff --git a/src/data/api_key.go b/src/data/api_key.go index 630e2a4a..c71aebe5 100644 --- a/src/data/api_key.go +++ b/src/data/api_key.go @@ -7,6 +7,7 @@ import ( "github.com/Parallels/pd-api-service/data/models" "github.com/Parallels/pd-api-service/errors" "github.com/Parallels/pd-api-service/helpers" + "github.com/Parallels/pd-api-service/security/password" ) var ( @@ -58,8 +59,8 @@ func (j *JsonDatabase) CreateApiKey(ctx basecontext.ApiContext, apiKey models.Ap return nil, ErrApiKeyAlreadyExists } - // Hash the password with SHA-256 - hashSecret, err := helpers.BcryptHash(apiKey.Secret, apiKey.ID) + passwdSvc := password.Get() + hashSecret, err := passwdSvc.Hash(apiKey.Secret, apiKey.ID) if err != nil { return nil, err } diff --git a/src/data/models/user.go b/src/data/models/user.go index a36a1065..c1e9f01d 100644 --- a/src/data/models/user.go +++ b/src/data/models/user.go @@ -1,13 +1,17 @@ package models type User struct { - ID string `json:"id,omitempty"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` - Password string `json:"password,omitempty"` - CreatedAt string `json:"created_at,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` - Roles []Role `json:"roles,omitempty"` - Claims []Claim `json:"claims,omitempty"` + ID string `json:"id,omitempty"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + Password string `json:"password,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + Roles []Role `json:"roles,omitempty"` + Claims []Claim `json:"claims,omitempty"` + FailedLoginAttempts int `json:"failed_login_attempts,omitempty"` + Blocked bool `json:"blocked,omitempty"` + BlockedSince string `json:"blocked_since,omitempty"` + BlockedReason string `json:"blocked_reason,omitempty"` } diff --git a/src/data/users.go b/src/data/users.go index 36d177ff..3916706a 100644 --- a/src/data/users.go +++ b/src/data/users.go @@ -8,6 +8,7 @@ import ( "github.com/Parallels/pd-api-service/data/models" "github.com/Parallels/pd-api-service/errors" "github.com/Parallels/pd-api-service/helpers" + "github.com/Parallels/pd-api-service/security/password" ) var ( @@ -127,8 +128,8 @@ func (j *JsonDatabase) CreateUser(ctx basecontext.ApiContext, user models.User) } } - // Hash the password with SHA-256 - hashedPassword, err := helpers.BcryptHash(user.Password, user.ID) + passwdSvc := password.Get() + hashedPassword, err := passwdSvc.Hash(user.Password, user.ID) if err != nil { return nil, err } @@ -159,7 +160,8 @@ func (j *JsonDatabase) UpdateUser(ctx basecontext.ApiContext, key models.User) e j.data.Users[i].Name = key.Name } if key.Password != "" { - hashedPassword, err := helpers.BcryptHash(key.Password, key.ID) + passwdSvc := password.Get() + hashedPassword, err := passwdSvc.Hash(key.Password, key.ID) if err != nil { return err } @@ -177,6 +179,28 @@ func (j *JsonDatabase) UpdateUser(ctx basecontext.ApiContext, key models.User) e return ErrUserNotFound } +func (j *JsonDatabase) UpdateUserBlockStatus(ctx basecontext.ApiContext, key models.User) error { + if !j.IsConnected() { + return ErrDatabaseNotConnected + } + + for i, user := range j.data.Users { + if user.ID == key.ID { + j.data.Users[i].Blocked = key.Blocked + j.data.Users[i].BlockedSince = key.BlockedSince + j.data.Users[i].BlockedReason = key.BlockedReason + j.data.Users[i].FailedLoginAttempts = key.FailedLoginAttempts + j.data.Users[i].UpdatedAt = helpers.GetUtcCurrentDateTime() + if err := j.Save(ctx); err != nil { + return err + } + return nil + } + } + + return ErrUserNotFound +} + func (j *JsonDatabase) UpdateRootPassword(ctx basecontext.ApiContext, newPassword string) error { if !j.IsConnected() { return ErrDatabaseNotConnected @@ -184,7 +208,8 @@ func (j *JsonDatabase) UpdateRootPassword(ctx basecontext.ApiContext, newPasswor for i, user := range j.data.Users { if user.Email == "root@localhost" { - hashedPassword, err := helpers.BcryptHash(newPassword, user.ID) + passwdSvc := password.Get() + hashedPassword, err := passwdSvc.Hash(newPassword, user.ID) if err != nil { return err } diff --git a/src/errors/common.go b/src/errors/common.go index 611e7918..43ef0d21 100644 --- a/src/errors/common.go +++ b/src/errors/common.go @@ -1,5 +1,9 @@ package errors +func ErrNotFound() error { + return NewWithCode("not found", 404) +} + func ErrValueEmpty() error { return NewWithCode("value cannot be empty", 400) } diff --git a/src/errors/diagnostic.go b/src/errors/diagnostic.go new file mode 100644 index 00000000..3c00620f --- /dev/null +++ b/src/errors/diagnostic.go @@ -0,0 +1,76 @@ +package errors + +import ( + "fmt" + "strings" +) + +type Diagnostics struct { + errors []error + warnings []error +} + +func NewDiagnostics() *Diagnostics { + return &Diagnostics{ + errors: []error{}, + warnings: []error{}, + } +} + +func (d *Diagnostics) AddError(err error) { + d.errors = append(d.errors, err) +} + +func (d *Diagnostics) AddWarning(err error) { + d.warnings = append(d.warnings, err) +} + +func (d *Diagnostics) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d *Diagnostics) HasWarnings() bool { + return len(d.warnings) > 0 +} + +func (d *Diagnostics) Errors() []error { + return d.errors +} + +func (d *Diagnostics) Warnings() []error { + return d.warnings +} + +func (d *Diagnostics) Append(diagnostics *Diagnostics) { + d.errors = append(d.errors, diagnostics.errors...) + d.warnings = append(d.warnings, diagnostics.warnings...) +} + +func (d *Diagnostics) Error() string { + msg := "" + if len(d.errors) > 0 { + if len(d.errors) == 1 { + return d.errors[0].Error() + } else { + msg = "errors:\n" + for _, err := range d.errors { + errMsg := strings.ReplaceAll(err.Error(), "error: ", "") + msg = fmt.Sprintf("%v\t%v\n", msg, errMsg) + } + } + } + + if len(d.warnings) > 0 { + if len(d.warnings) == 1 { + return d.warnings[0].Error() + } else { + msg = "warnings:\n" + for _, err := range d.errors { + errMsg := strings.ReplaceAll(err.Error(), "error: ", "") + msg = fmt.Sprintf("%v\t%v\n", msg, errMsg) + } + } + } + + return msg +} diff --git a/src/go.mod b/src/go.mod index 42443bf3..165de287 100644 --- a/src/go.mod +++ b/src/go.mod @@ -6,9 +6,10 @@ require ( github.com/Azure/azure-storage-blob-go v0.15.0 github.com/aws/aws-sdk-go v1.49.21 github.com/cjlapao/common-go v0.0.39 + github.com/cjlapao/common-go-cryptorand v0.0.6 github.com/cjlapao/common-go-logger v0.0.5 - github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/go-sql-driver/mysql v1.7.1 + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/uuid v1.5.0 github.com/gorilla/handlers v1.5.2 github.com/gorilla/mux v1.8.1 @@ -17,6 +18,7 @@ require ( github.com/swaggo/files/v2 v2.0.0 github.com/swaggo/swag v1.16.2 golang.org/x/crypto v0.18.0 + gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -43,7 +45,6 @@ require ( github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/spec v0.20.9 // indirect github.com/go-openapi/swag v0.22.4 // indirect - github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/snappy v0.0.2 // indirect github.com/gookit/color v1.5.4 // indirect diff --git a/src/go.sum b/src/go.sum index de3d6eff..524d0114 100644 --- a/src/go.sum +++ b/src/go.sum @@ -37,6 +37,8 @@ github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1l github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/cjlapao/common-go v0.0.39 h1:bAAUrj2B9v0kMzbAOhzjSmiyDy+rd56r2sy7oEiQLlA= github.com/cjlapao/common-go v0.0.39/go.mod h1:M3dzazLjTjEtZJbbxoA5ZDiGCiHmpwqW9l4UWaddwOA= +github.com/cjlapao/common-go-cryptorand v0.0.6 h1:0XpMIlu2Hbu5JEq4O/3RxUgo68h21mkElak5HxdjhuQ= +github.com/cjlapao/common-go-cryptorand v0.0.6/go.mod h1:IR5isk32OIQ/yLbZUOmKR7vVo5OzTpfeX0xAagHsQyU= github.com/cjlapao/common-go-logger v0.0.5 h1:YyO0lA4Uav6jwD4PT2gfv9Iu0rj5MtlOdKyop4Y5gFQ= github.com/cjlapao/common-go-logger v0.0.5/go.mod h1:bF2s2y2as4Fwz2Ox3QTkWw1Y02a07jX5ey8smY5inrU= github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= @@ -47,8 +49,6 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= @@ -295,6 +295,8 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/src/helpers/strings.go b/src/helpers/strings.go index 8f524505..fdbbf4d1 100644 --- a/src/helpers/strings.go +++ b/src/helpers/strings.go @@ -2,14 +2,12 @@ package helpers import ( "crypto/rand" - "crypto/sha256" "encoding/hex" "math" "strconv" "strings" "github.com/Parallels/pd-api-service/errors" - "golang.org/x/crypto/bcrypt" ) func GenerateId() string { @@ -20,47 +18,6 @@ func GenerateId() string { return hex.EncodeToString(bytes) } -func Sha256Hash(input string) (string, error) { - hashedPassword := sha256.Sum256([]byte(input)) - return hex.EncodeToString(hashedPassword[:]), nil -} - -func BcryptHash(input string, salt string) (string, error) { - cost := bcrypt.DefaultCost - // saltString := GenerateSalt(salt, cost) - inputBytes := []byte(input) - saltBytes := []byte(salt) - if len(inputBytes) > 40 { - return "", errors.New("password cannot be longer than 42 characters") - } - if len(saltBytes) > 32 { - saltBytes = saltBytes[:32] - } - - saltedPwd := []byte(input + string(saltBytes)) - - bytes, err := bcrypt.GenerateFromPassword([]byte(saltedPwd), cost) - if err != nil { - return "", err - } - return string(bytes), nil -} - -func BcryptCompare(input string, salt string, hashedPwd string) error { - saltBytes := []byte(salt) - if len(saltBytes) > 32 { - saltBytes = saltBytes[:32] - } - - saltedPwd := []byte(input + string(saltBytes)) - - err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), saltedPwd) - if err != nil { - return err - } - return nil -} - func ConvertByteToGigabyte(bytes float64) float64 { gb := float64(bytes) / 1024 / 1024 / 1024 return math.Round(gb*100) / 100 diff --git a/src/helpers/strings_test.go b/src/helpers/strings_test.go deleted file mode 100644 index 1a7fd1ba..00000000 --- a/src/helpers/strings_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package helpers - -import ( - "testing" - - "golang.org/x/crypto/bcrypt" -) - -func TestBcryptHash(t *testing.T) { - input := "password" - salt := "somesalt" - - hashedPwd, err := BcryptHash(input, salt) - if err != nil { - t.Errorf("Error hashing password: %v", err) - } - - err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(input+salt)) - if err != nil { - t.Errorf("Hashed password does not match input: %v", err) - } -} diff --git a/src/restapi/apikey_authorization_middleware.go b/src/restapi/apikey_authorization_middleware.go index 4c7e76c2..153a8a64 100644 --- a/src/restapi/apikey_authorization_middleware.go +++ b/src/restapi/apikey_authorization_middleware.go @@ -7,8 +7,8 @@ import ( "github.com/Parallels/pd-api-service/basecontext" "github.com/Parallels/pd-api-service/constants" "github.com/Parallels/pd-api-service/errors" - "github.com/Parallels/pd-api-service/helpers" "github.com/Parallels/pd-api-service/models" + "github.com/Parallels/pd-api-service/security/password" "github.com/Parallels/pd-api-service/serviceprovider" "net/http" @@ -69,7 +69,8 @@ func ApiKeyAuthorizationMiddlewareAdapter(roles []string, claims []string) Adapt } } if isValid { - if err := helpers.BcryptCompare(apiKey.Value, dbApiKey.ID, dbApiKey.Secret); err != nil { + passwdSvc := password.Get() + if err := passwdSvc.Compare(apiKey.Value, dbApiKey.ID, dbApiKey.Secret); err != nil { isValid = false authError.ErrorDescription = "Api Key is not Valid" } diff --git a/src/restapi/token_authorization_middleware.go b/src/restapi/token_authorization_middleware.go index 52404713..7c7b8c8a 100644 --- a/src/restapi/token_authorization_middleware.go +++ b/src/restapi/token_authorization_middleware.go @@ -3,19 +3,19 @@ package restapi import ( "context" "errors" + "fmt" "net/http" "strings" "github.com/Parallels/pd-api-service/basecontext" - "github.com/Parallels/pd-api-service/config" "github.com/Parallels/pd-api-service/constants" data_modules "github.com/Parallels/pd-api-service/data/models" "github.com/Parallels/pd-api-service/mappers" "github.com/Parallels/pd-api-service/models" + "github.com/Parallels/pd-api-service/security/jwt" "github.com/Parallels/pd-api-service/serviceprovider" "github.com/cjlapao/common-go/helper/http_helper" - "github.com/dgrijalva/jwt-go" ) // TokenAuthorizationMiddlewareAdapter validates a Authorization Bearer during a rest api call @@ -26,7 +26,6 @@ import ( func TokenAuthorizationMiddlewareAdapter(roles []string, claims []string) Adapter { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cfg := config.NewConfig() baseCtx := basecontext.NewBaseContextFromRequest(r) var authorizationContext *basecontext.AuthorizationContext authCtxFromRequest := baseCtx.GetAuthorizationContext() @@ -63,7 +62,7 @@ func TokenAuthorizationMiddlewareAdapter(roles []string, claims []string) Adapte // Setting the tenant in the context authorizationContext.Issuer = "Global" - //Starting authorization layer of the token + // Starting authorization layer of the token authorized := true baseCtx.LogInfo("Token Authorization layer started") @@ -76,136 +75,173 @@ func TokenAuthorizationMiddlewareAdapter(roles []string, claims []string) Adapte } // Validating userToken against the keys + var token *jwt.JwtSystemToken + // Validating if the token can be parsed if authorized { - token, err := jwt.Parse(jwt_token, func(token *jwt.Token) (interface{}, error) { - // Validate the algorithm - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, jwt.ErrSignatureInvalid + jwtSvc := jwt.Get() + var err error + token, err = jwtSvc.Parse(jwt_token) + if err != nil || token == nil { + authorized = false + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: err.Error(), } + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) + } + } - // Return the secret key used to sign the token - // We either signing the token with the HMAC secret or the secret from the config - var key []byte - if cfg.GetHmacSecret() == "" { - key = []byte(cfg.GetHmacSecret()) - } else { - key = []byte(serviceprovider.Get().HardwareSecret) + // Validating if the token is valid + if authorized { + valid, err := token.Valid() + if err != nil || !valid { + authorized = false + if err == nil { + err = errors.New("invalid token") } - return key, nil - }) + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: err.Error(), + } + + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) + } + } + // Validating if the token has the correct email + var email interface{} + if authorized { + var err error + email, err = token.GetClaim("email") if err != nil { authorized = false response := models.OAuthErrorResponse{ Error: models.OAuthUnauthorizedClient, ErrorDescription: err.Error(), } + authorizationContext.IsAuthorized = false authorizationContext.AuthorizationError = &response baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) } + } + // Validating if the token has the correct user + var dbUser *data_modules.User + if authorized { + db := serviceprovider.Get().JsonDatabase + var err error + + // validating if the database is connected + if err = db.Connect(baseCtx); err != nil { + authorized = false + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: fmt.Sprintf("Error connecting to database, %v", err.Error()), + } + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) + } + + // validating if the user exists if authorized { - // Check if the token is valid - if jwtClaims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - db := serviceprovider.Get().JsonDatabase - var dbUser *data_modules.User - var err error - if err = db.Connect(baseCtx); err != nil { - authorized = false - } else { - dbUser, err = db.GetUser(baseCtx, jwtClaims["email"].(string)) - if err != nil || dbUser == nil { - authorized = false + dbUser, err = db.GetUser(baseCtx, email.(string)) + if err != nil || dbUser == nil { + authorized = false + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: fmt.Sprintf("Error connecting to database, %v", err.Error()), + } + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) + } + } + + if authorized { + // Checking for the Super Duper User + authorizationContext.IsSuperUser = false + for _, userRole := range dbUser.Roles { + if strings.EqualFold(constants.SUPER_USER_ROLE, userRole.Name) { + authorizationContext.IsSuperUser = true + break + } + } + + // Validating if the user has the correct roles and claims + if !authorizationContext.IsSuperUser { + // Validating if the user has the correct roles + if len(roles) > 0 { + rolesCheck := TokenRoleClaimValidationList{} + for _, role := range roles { + roleCheck := &TokenRoleClaimValidation{Name: role} + for _, userRole := range dbUser.Roles { + if strings.EqualFold(role, userRole.Name) { + roleCheck.SetExists(true) + break + } + } + rolesCheck = append(rolesCheck, roleCheck) } - // Checking for the Super Duper User - authorizationContext.IsSuperUser = false - for _, userRole := range dbUser.Roles { - if strings.EqualFold(constants.SUPER_USER_ROLE, userRole.Name) { - authorizationContext.IsSuperUser = true - break + + if len(roles) != len(rolesCheck) || !rolesCheck.Exists() { + failed := rolesCheck.GetFailed() + authorized = false + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: fmt.Sprintf("User does not contain enough permissions, does not have roles, %v", failed), } + + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) } - if !authorizationContext.IsSuperUser { - // Checking if the user has the correct role required by the controller - if len(roles) > 0 { - contains := false - for _, role := range roles { - for _, userRole := range dbUser.Roles { - if strings.EqualFold(role, userRole.Name) { - contains = true - break - } - } - if contains { + } + + if authorized { + // Validating if the user has the correct claims + if len(claims) > 0 { + claimsCheck := TokenRoleClaimValidationList{} + for _, claim := range claims { + claimCheck := &TokenRoleClaimValidation{Name: claim} + for _, userClaim := range dbUser.Claims { + if strings.EqualFold(claim, userClaim.Name) { + claimCheck.SetExists(true) break } } - if !contains { - authorized = false - response := models.OAuthErrorResponse{ - Error: models.OAuthUnauthorizedClient, - ErrorDescription: "User does not contain enough permissions, not in role", - } - authorizationContext.IsAuthorized = false - authorizationContext.AuthorizationError = &response - baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) - } + claimsCheck = append(claimsCheck, claimCheck) } - if len(claims) > 0 { - contains := false - for _, claim := range claims { - for _, userClaim := range dbUser.Claims { - if strings.EqualFold(claim, userClaim.Name) { - contains = true - break - } - } - if contains { - break - } - } - if !contains { - authorized = false - response := models.OAuthErrorResponse{ - Error: models.OAuthUnauthorizedClient, - ErrorDescription: "User does not contain enough permissions, does not have claim", - } - authorizationContext.IsAuthorized = false - authorizationContext.AuthorizationError = &response - baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) + if len(claims) != len(claimsCheck) || !claimsCheck.Exists() { + failed := claimsCheck.GetFailed() + authorized = false + response := models.OAuthErrorResponse{ + Error: models.OAuthUnauthorizedClient, + ErrorDescription: fmt.Sprintf("User does not contain enough permissions, does not have claims, %v", failed), } + + authorizationContext.IsAuthorized = false + authorizationContext.AuthorizationError = &response + baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) } } } - if !authorized { - response := models.OAuthErrorResponse{ - Error: models.OAuthUnauthorizedClient, - ErrorDescription: "User not found", - } - authorizationContext.IsAuthorized = false - if authorizationContext.AuthorizationError == nil { - authorizationContext.AuthorizationError = &response - } - baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) - } else { - user := mappers.DtoUserToApiResponse(*dbUser) - authorizationContext.User = &user - authorizationContext.IsAuthorized = true - authorizationContext.AuthorizedBy = "TokenAuthorization" - } - } else { - response := models.OAuthErrorResponse{ - Error: models.OAuthUnauthorizedClient, - ErrorDescription: "Token is not valid", - } - authorizationContext.IsAuthorized = false - authorizationContext.AuthorizationError = &response - baseCtx.LogError("Request failed to authorize, %v", response.ErrorDescription) } } } + if authorized { + user := mappers.DtoUserToApiResponse(*dbUser) + authorizationContext.User = &user + authorizationContext.IsAuthorized = true + authorizationContext.AuthorizedBy = "TokenAuthorization" + } + ctx := context.WithValue(r.Context(), constants.AUTHORIZATION_CONTEXT_KEY, authorizationContext) if authorizationContext.User != nil { baseCtx.LogInfo("Token Authorization layer finished, user %v authorized", authorizationContext.User.Email) diff --git a/src/restapi/token_role_claim_validation.go b/src/restapi/token_role_claim_validation.go new file mode 100644 index 00000000..f5db8741 --- /dev/null +++ b/src/restapi/token_role_claim_validation.go @@ -0,0 +1,36 @@ +package restapi + +type TokenRoleClaimValidation struct { + Name string + exists bool +} + +func (s *TokenRoleClaimValidation) Exists() bool { + return s.exists +} + +func (s *TokenRoleClaimValidation) SetExists(exists bool) { + s.exists = exists +} + +type TokenRoleClaimValidationList []*TokenRoleClaimValidation + +func (s TokenRoleClaimValidationList) Exists() bool { + for _, item := range s { + if !item.Exists() { + return false + } + } + return true +} + +func (s TokenRoleClaimValidationList) GetFailed() []string { + failed := []string{} + for _, item := range s { + if !item.Exists() { + failed = append(failed, item.Name) + } + } + + return failed +} diff --git a/src/security/brute_force_guard/main.go b/src/security/brute_force_guard/main.go new file mode 100644 index 00000000..1411cf77 --- /dev/null +++ b/src/security/brute_force_guard/main.go @@ -0,0 +1,164 @@ +package bruteforceguard + +import ( + "strconv" + "time" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/Parallels/pd-api-service/config" + "github.com/Parallels/pd-api-service/constants" + "github.com/Parallels/pd-api-service/errors" + "github.com/Parallels/pd-api-service/serviceprovider" +) + +var globalBruteForceGuard *BruteForceGuard + +type BruteForceGuard struct { + ctx basecontext.ApiContext + options *BruteForceGuardOptions +} + +func New(ctx basecontext.ApiContext) *BruteForceGuard { + globalBruteForceGuard = &BruteForceGuard{ + ctx: ctx, + options: NewDefaultOptions(), + } + + globalBruteForceGuard.processEnvironmentVariables() + return globalBruteForceGuard +} + +func Get() *BruteForceGuard { + if globalBruteForceGuard == nil { + ctx := basecontext.NewRootBaseContext() + return New(ctx) + } + + return globalBruteForceGuard +} + +func (s *BruteForceGuard) WithMaxLoginAttempts(maxAttempts int) *BruteForceGuard { + s.options.WithMaxLoginAttempts(maxAttempts) + return s +} + +func (s *BruteForceGuard) WithBlockDuration(duration string) *BruteForceGuard { + s.options.WithBlockDuration(duration) + return s +} + +func (s *BruteForceGuard) WithIncrementalWait(incremental bool) *BruteForceGuard { + s.options.WithIncrementalWait(incremental) + return s +} + +func (s *BruteForceGuard) Options() *BruteForceGuardOptions { + return s.options +} + +func (s *BruteForceGuard) Process(userId string, loginState bool, reason string) *errors.Diagnostics { + diag := errors.NewDiagnostics() + dbService, err := serviceprovider.GetDatabaseService(s.ctx) + if err != nil { + diag.AddError(err) + return diag + } + + user, err := dbService.GetUser(s.ctx, userId) + if err != nil { + diag.AddError(err) + return diag + } + + if user == nil { + diag.AddError(errors.ErrNotFound()) + return diag + } + + if loginState { + user.FailedLoginAttempts = 0 + user.BlockedSince = "" + user.Blocked = false + user.BlockedReason = "" + err := dbService.UpdateUserBlockStatus(s.ctx, *user) + if err != nil { + diag.AddError(err) + } + return diag + } else { + user.FailedLoginAttempts++ + + if user.FailedLoginAttempts >= s.options.MaxLoginAttempts() { + user.Blocked = true + user.BlockedSince = time.Now().Format(time.RFC3339) + user.BlockedReason = reason + err := dbService.UpdateUserBlockStatus(s.ctx, *user) + if err != nil { + diag.AddError(err) + return diag + } + if s.options.IncrementalWait() { + countExtraAttempts := user.FailedLoginAttempts - (s.options.MaxLoginAttempts() - 1) + sleepFor := time.Duration(s.options.BlockDuration().Seconds()*float64(countExtraAttempts)) * time.Second + time.Sleep(sleepFor) + } else { + sleepFor := s.options.BlockDuration() + time.Sleep(sleepFor) + } + } else { + err := dbService.UpdateUserBlockStatus(s.ctx, *user) + if err != nil { + diag.AddError(err) + return diag + } + } + } + + return diag +} + +func (s *BruteForceGuard) IsBlocked(userId string) bool { + dbService, err := serviceprovider.GetDatabaseService(s.ctx) + if err != nil { + return false + } + + user, err := dbService.GetUser(s.ctx, userId) + if err != nil { + return false + } + + if user == nil { + return false + } + + return user.Blocked +} + +func (s *BruteForceGuard) processEnvironmentVariables() { + cfg := config.NewConfig() + if cfg.GetKey(constants.BRUTE_FORCE_MAX_LOGIN_ATTEMPTS_ENV_VAR) != "" { + maxLoginAttempts, err := strconv.Atoi(cfg.GetKey(constants.BRUTE_FORCE_MAX_LOGIN_ATTEMPTS_ENV_VAR)) + if err != nil { + s.ctx.LogWarn("[BruteForceGuard] Invalid value for %s: %s", constants.BRUTE_FORCE_MAX_LOGIN_ATTEMPTS_ENV_VAR, err.Error()) + } else { + s.ctx.LogDebug("[BruteForceGuard] Setting %s to %d", constants.BRUTE_FORCE_MAX_LOGIN_ATTEMPTS_ENV_VAR, maxLoginAttempts) + s.options.WithMaxLoginAttempts(maxLoginAttempts) + } + } + + if cfg.GetKey(constants.BRUTE_FORCE_LOCKOUT_DURATION_ENV_VAR) != "" { + s.ctx.LogDebug("[BruteForceGuard] Setting %s to %s", constants.BRUTE_FORCE_LOCKOUT_DURATION_ENV_VAR, cfg.GetKey(constants.BRUTE_FORCE_LOCKOUT_DURATION_ENV_VAR)) + s.options.WithBlockDuration(cfg.GetKey(constants.BRUTE_FORCE_LOCKOUT_DURATION_ENV_VAR)) + } + + if cfg.GetKey(constants.BRUTE_FORCE_INCREMENTAL_WAIT_ENV_VAR) != "" { + incrementalWait, err := strconv.ParseBool(cfg.GetKey(constants.BRUTE_FORCE_INCREMENTAL_WAIT_ENV_VAR)) + if err != nil { + s.ctx.LogWarn("[BruteForceGuard] Invalid value for %s: %s", constants.BRUTE_FORCE_INCREMENTAL_WAIT_ENV_VAR, err.Error()) + } else { + s.ctx.LogInfo("[BruteForceGuard] Setting %s to %v", constants.BRUTE_FORCE_INCREMENTAL_WAIT_ENV_VAR, incrementalWait) + s.options.WithIncrementalWait(incrementalWait) + } + } +} diff --git a/src/security/brute_force_guard/options.go b/src/security/brute_force_guard/options.go new file mode 100644 index 00000000..7aab687f --- /dev/null +++ b/src/security/brute_force_guard/options.go @@ -0,0 +1,54 @@ +package bruteforceguard + +import ( + "time" +) + +type BruteForceGuardOptions struct { + maxFailedLoginAttempts int + blockDuration string + incrementalWait bool +} + +func NewDefaultOptions() *BruteForceGuardOptions { + return &BruteForceGuardOptions{ + maxFailedLoginAttempts: 5, + blockDuration: "5s", + incrementalWait: true, + } +} + +func (bfg *BruteForceGuardOptions) WithMaxLoginAttempts(attempts int) *BruteForceGuardOptions { + if attempts < 1 { + attempts = 1 + } + bfg.maxFailedLoginAttempts = attempts + return bfg +} + +func (bfg *BruteForceGuardOptions) WithBlockDuration(duration string) *BruteForceGuardOptions { + bfg.blockDuration = duration + return bfg +} + +func (bfg *BruteForceGuardOptions) WithIncrementalWait(incremental bool) *BruteForceGuardOptions { + bfg.incrementalWait = incremental + return bfg +} + +func (bfg *BruteForceGuardOptions) BlockDuration() time.Duration { + duration, err := time.ParseDuration(bfg.blockDuration) + if err != nil { + return time.Second * 5 + } + + return duration +} + +func (bfg *BruteForceGuardOptions) MaxLoginAttempts() int { + return bfg.maxFailedLoginAttempts +} + +func (bfg *BruteForceGuardOptions) IncrementalWait() bool { + return bfg.incrementalWait +} diff --git a/src/security/jwt/main.go b/src/security/jwt/main.go new file mode 100644 index 00000000..51e1da73 --- /dev/null +++ b/src/security/jwt/main.go @@ -0,0 +1,269 @@ +package jwt + +import ( + "errors" + "strconv" + "time" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/Parallels/pd-api-service/config" + "github.com/Parallels/pd-api-service/constants" + "github.com/Parallels/pd-api-service/security" + "github.com/golang-jwt/jwt/v4" + "gopkg.in/square/go-jose.v2" +) + +var globalJwtService *JwtService + +type JwtService struct { + ctx basecontext.ApiContext + Options *JwtOptions +} + +func New(ctx basecontext.ApiContext) *JwtService { + globalJwtService = &JwtService{ + ctx: ctx, + Options: NewDefaultOptions(ctx), + } + + err := globalJwtService.processEnvironmentVariables() + if err != nil { + ctx.LogError("Error processing environment variables for jwt options: %s", err.Error()) + } + + return globalJwtService +} + +func Get() *JwtService { + if globalJwtService == nil { + ctx := basecontext.NewRootBaseContext() + return New(ctx) + } + + return globalJwtService +} + +func (s *JwtService) WithTokenDuration(durationInMinutes float64) *JwtService { + s.Options.WithTokenDuration(durationInMinutes) + return s +} + +func (s *JwtService) WithSecret(secret string) *JwtService { + s.Options.WithSecret(secret) + return s +} + +func (s *JwtService) WithPrivateKey(privateKey string) *JwtService { + s.Options.WithPrivateKey(privateKey) + return s +} + +func (s *JwtService) WithAlgorithm(algorithm JwtSigningAlgorithm) *JwtService { + s.Options.WithAlgorithm(algorithm) + return s +} + +func (s *JwtService) Sign(claims map[string]interface{}) (string, error) { + if claims["email"] == "" { + return "", errors.New("email cannot be empty") + } + + expiresAt := time.Now().Add(s.Options.TokenDuration).Unix() + var method jwt.SigningMethod + + switch s.Options.Algorithm { + case JwtSigningAlgorithmHS256: + method = jwt.SigningMethodHS256 + case JwtSigningAlgorithmHS384: + method = jwt.SigningMethodHS384 + case JwtSigningAlgorithmHS512: + method = jwt.SigningMethodHS512 + case JwtSigningAlgorithmRS256: + method = jwt.SigningMethodRS256 + case JwtSigningAlgorithmRS384: + method = jwt.SigningMethodRS384 + case JwtSigningAlgorithmRS512: + method = jwt.SigningMethodRS512 + default: + method = jwt.SigningMethodHS256 + s.Options.Algorithm = JwtSigningAlgorithmHS256 + } + if claims["roles"] == nil { + claims["roles"] = []string{} + } + if claims["claims"] == nil { + claims["claims"] = map[string]interface{}{} + } + + defaultClaims := jwt.MapClaims{ + "exp": expiresAt, + } + + for k, v := range claims { + defaultClaims[k] = v + } + + token := jwt.NewWithClaims(method, defaultClaims) + + var key interface{} + + switch s.Options.Algorithm { + case JwtSigningAlgorithmHS256, JwtSigningAlgorithmHS384, JwtSigningAlgorithmHS512: + if s.Options.Secret != "" { + key = []byte(s.Options.Secret) + } else { + return "", errors.New("secret cannot be empty") + } + case JwtSigningAlgorithmRS256, JwtSigningAlgorithmRS384, JwtSigningAlgorithmRS512: + if s.Options.PrivateKey != "" { + decodedKey, err := security.Base64Decode(s.Options.PrivateKey) + if err != nil { + return "", err + } + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(decodedKey) + if err != nil { + return "", err + } + + key = privateKey + } else { + return "", errors.New("private key cannot be empty") + } + } + + tokenString, err := token.SignedString(key) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func (s *JwtService) GenerateJWKS() (string, error) { + if s.Options.PrivateKey == "" { + return "", errors.New("private key cannot be empty") + } + + decodedKey, err := security.Base64Decode(s.Options.PrivateKey) + if err != nil { + return "", err + } + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(decodedKey) + if err != nil { + return "", err + } + + var algorithm string + switch s.Options.Algorithm { + case JwtSigningAlgorithmRS256: + algorithm = "RS256" + case JwtSigningAlgorithmRS384: + algorithm = "RS384" + case JwtSigningAlgorithmRS512: + algorithm = "RS512" + default: + algorithm = "RS256" + } + + thumbprint, err := security.CalculatePrivateKeyThumbprint(privateKey) + if err != nil { + return "", err + } + + jwk := jose.JSONWebKey{Key: privateKey, KeyID: thumbprint, Algorithm: algorithm} + + jwkBytes, err := jwk.MarshalJSON() + if err != nil { + return "", err + } + + return string(jwkBytes), nil +} + +func (s *JwtService) Parse(token string) (*JwtSystemToken, error) { + var key interface{} + + switch s.Options.Algorithm { + case JwtSigningAlgorithmHS256, JwtSigningAlgorithmHS384, JwtSigningAlgorithmHS512: + if s.Options.Secret != "" { + key = []byte(s.Options.Secret) + } else { + return nil, errors.New("secret cannot be empty") + } + case JwtSigningAlgorithmRS256, JwtSigningAlgorithmRS384, JwtSigningAlgorithmRS512: + if s.Options.PrivateKey != "" { + decodedKey, err := security.Base64Decode(s.Options.PrivateKey) + if err != nil { + return nil, err + } + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(decodedKey) + if err != nil { + return nil, err + } + publicKey := privateKey.Public() + if err != nil { + return nil, err + } + key = publicKey + } else { + return nil, errors.New("private key cannot be empty") + } + } + + tokenObj, _ := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + return key, nil + }) + + systemToken := &JwtSystemToken{ + token: token, + tokenObj: tokenObj, + } + _, _ = systemToken.GetTokenClaims() + + return systemToken, nil +} + +func (s *JwtService) processEnvironmentVariables() error { + cfg := config.NewConfig() + if cfg.GetKey(constants.JWT_SIGN_ALGORITHM_ENV_VAR) != "" { + algorithm := JwtSigningAlgorithm(cfg.GetKey(constants.JWT_SIGN_ALGORITHM_ENV_VAR)) + switch algorithm { + case JwtSigningAlgorithmHS256, JwtSigningAlgorithmHS384, JwtSigningAlgorithmHS512, + JwtSigningAlgorithmRS256, JwtSigningAlgorithmRS384, JwtSigningAlgorithmRS512: + default: + return errors.New("invalid signing algorithm") + } + + s.Options.WithAlgorithm(algorithm) + } + + if cfg.GetKey(constants.JWT_HMACS_SECRET_ENV_VAR) != "" { + s.Options.WithSecret(cfg.GetKey(constants.JWT_HMACS_SECRET_ENV_VAR)) + } + + if cfg.GetKey(constants.JWT_PRIVATE_KEY_ENV_VAR) != "" { + s.Options.WithPrivateKey(cfg.GetKey(constants.JWT_PRIVATE_KEY_ENV_VAR)) + } + + if cfg.GetKey(constants.JWT_DURATION_ENV_VAR) != "" { + durationInMinutes, err := strconv.ParseFloat(cfg.GetKey(constants.JWT_DURATION_ENV_VAR), 64) + if err != nil { + return err + } + s.Options.WithTokenDuration(durationInMinutes) + } + + // generating a default secret if none is provided + if s.Options.Algorithm == JwtSigningAlgorithmHS256 || s.Options.Algorithm == JwtSigningAlgorithmHS384 || s.Options.Algorithm == JwtSigningAlgorithmHS512 { + if s.Options.Secret == "" { + randStr, err := security.GenerateCryptoRandomString(80) + if err != nil { + s.ctx.LogError("Error generating random string: %s", err.Error()) + return err + } + s.Options.WithSecret(randStr) + } + } + + return nil +} diff --git a/src/security/jwt/main_test.go b/src/security/jwt/main_test.go new file mode 100644 index 00000000..95274891 --- /dev/null +++ b/src/security/jwt/main_test.go @@ -0,0 +1,586 @@ +package jwt + +import ( + "errors" + "os" + "testing" + "time" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/Parallels/pd-api-service/constants" + "github.com/stretchr/testify/assert" +) + +func TestGet(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + globalJwtService = nil + + svc := Get() + assert.NotNil(t, svc) + + assert.Equal(t, ctx, svc.ctx) + + svc2 := Get() + assert.Equal(t, svc, svc2) +} + +func TestJwtService_SignHS256(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJwtService_SignHS384(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + svc.WithAlgorithm(JwtSigningAlgorithmHS384) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJwtService_SignHS512(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + svc.WithAlgorithm(JwtSigningAlgorithmHS512) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJwtService_SignRandomSecret(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJwtService_SignRS256(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + svc.WithAlgorithm(JwtSigningAlgorithmRS256) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + +} + +func TestJwtService_SignRS384(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + svc.WithAlgorithm(JwtSigningAlgorithmRS384) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + +} + +func TestJwtService_SignRS512(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + svc.WithAlgorithm(JwtSigningAlgorithmRS512) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + +} + +func TestJwtService_SignDefault(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + svc.WithAlgorithm("") + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + token, err := svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 2: Sign with empty email + claims["email"] = "" + token, err = svc.Sign(claims) + assert.Error(t, err) + assert.Empty(t, token) + + // Test case 3: Sign with empty roles + claims["email"] = "test@example.com" + claims["roles"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Test case 4: Sign with empty claims + claims["email"] = "test@example.com" + claims["roles"] = []string{"admin", "user"} + claims["claims"] = []string{} + token, err = svc.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, token) + +} + +func TestJwtService_SignNoPrivateKey(t *testing.T) { + // Create a new instance of JwtService + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.WithPrivateKey("") + svc.WithAlgorithm(JwtSigningAlgorithmRS256) + + // Test case 1: Sign with valid input + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + _, err := svc.Sign(claims) + assert.Errorf(t, err, "private key cannot be empty") +} + +func TestGenerateJwksRS256Algorithm(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==" + svc.Options.Algorithm = JwtSigningAlgorithmRS256 + + expectedJWK := `{"kty":"RSA","kid":"ee8146d4b30a57d0053f39c80f4c3caa46461633","alg":"RS256","n":"rRM-1Pib0oeFCGOo5mNjsN8xOrL_n6It0MvuO3wRdpCewJfsV7IRpYnL5NLztWfuh6oklgvq20jM4451AzH3kndkx4XjcOtfh9ZIRo_qBXXxGF9XJWWzms21jZBs95Z_zwTZgpBU_mVhHRz5n9-UGSeEZYfu8ZyLxFFCroBV5RNkNrHC0IbNOVn8RuJzks89vyvS6DKlKcoHfjiUX0LMa5Fqq-nE-9Gy3QDsqdkaKpGBTnakseZB-hEvLPFX1T0TGd8sA48K12Z5YjIZ60ODziNpyvKn_n3YojZ1Lfz9C1Rl603qoenc1wcxYb-5XlXzjUd7JSFqfTOHtx_mcorzEQ","e":"AQAB","d":"bMEEIVUarP7FJFFjR2m6seB9kbH6mHeTLHmIeE5stsEHUGPmefCF0Cw3N9EqnJnzM8JA_Rv99s7XGEJi2qAiPiHR7OH_2eu8-qE2h0hVOBs1ZSg7nV87rZGHfK39GtKx-wbEGpvRHI3dqSqU3NXjvK6tLhNtnNrOpIyfRwGTd75f8JaDLT8Nu1jhWD5ZRWua2vq_yklFsDuPhePh2fjZJ6eAqIx5I3gkpwKue8IuWd5PTwvY3CAxH3qlWy9aGWDxohS-zL0mWOXg82ZM8jE1fEP0Y_kFmvR8lfQU7HYyi6HAxKatYCuWaRf83OT5F3R7_adCTW7eZjSvcpKuhOsBFQ","p":"2dB5m17kN-H90rN7dNZ9fMf_t7FFV4LxuYA2MvwZYC8SFGaMmdgx20LGNBk4Dx9Mz4p8t06CY_9T_0Djk9F8URlrQJl7tzVN8Qq-FHcnjAY9zV_DINfbAWoQ4ZQT3I8syiMbAK8uOK7ehXbR9IbaikfJjJc5d9Vg2BNAACGRL68","q":"y2rfZzUlw2JScRhq9T0BzVuFVSOpad13qKzVDs13EmFoGqk31Zt1EnvaiwU7uGu2bGTTgQK65V90C5zgZW07yz1A7tSVGEDWgi3rYznMQzCUUHXRx7cX9i5G8OnBDKFznVvGXKSOFVq3AP2cOkr3xmF25ST--xwG2Ca0lTNB-T8","dp":"oXQyICRHoNDIuB1IvwObAxqxB7XEg6jRi0JpaoOKP8zEZxDY2dTyp-eoScgD0NnPsuuhpLLyXjNOTSAJUXHv56Gi6cCbfuNpQepHmZ31V4rs1sZMOpUmhrbbioqb6lrKxY8eHfS8m1GsKlw4Jzyq0-OAl9EkzRoC7kfeofo_x4s","dq":"WuV9wJuiLUWxOzJDESTauk4MLXhLCrBY-PmKFxw--eqm30sAVSYrUUAg7wA-qHERSixfyoVSyI43x7ypFQmTr4TGkDJUEUtfzzn_tg4stVVu4OlU_V5Wib4yGxMJHcDDbeyFnf42M1qe7gVlmzLGt1H0E_7NJZ5nfI0HIqiN8Xc","qi":"umm077LJtRZ8A19-xYSglcc32423DS-nHz5s3sW-EdX3wKJCG1A4zFpQ3gNTcBtIlBMIcAQYspJVE78cxQCoZ5D9Br2BJBcTLOS0tOvRzSyacXtiBDK0d8dZ_8rf1pfDg1Gt-isN9JaUgPITeT7b-WaBBLFz_f5mWH0vK97fnAQ"}` + + jwk, err := svc.GenerateJWKS() + assert.NoError(t, err) + assert.Equal(t, expectedJWK, jwk) +} + +func TestGenerateJwksRS384Algorithm(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==" + svc.Options.Algorithm = JwtSigningAlgorithmRS384 + + expectedJWK := `{"kty":"RSA","kid":"ee8146d4b30a57d0053f39c80f4c3caa46461633","alg":"RS384","n":"rRM-1Pib0oeFCGOo5mNjsN8xOrL_n6It0MvuO3wRdpCewJfsV7IRpYnL5NLztWfuh6oklgvq20jM4451AzH3kndkx4XjcOtfh9ZIRo_qBXXxGF9XJWWzms21jZBs95Z_zwTZgpBU_mVhHRz5n9-UGSeEZYfu8ZyLxFFCroBV5RNkNrHC0IbNOVn8RuJzks89vyvS6DKlKcoHfjiUX0LMa5Fqq-nE-9Gy3QDsqdkaKpGBTnakseZB-hEvLPFX1T0TGd8sA48K12Z5YjIZ60ODziNpyvKn_n3YojZ1Lfz9C1Rl603qoenc1wcxYb-5XlXzjUd7JSFqfTOHtx_mcorzEQ","e":"AQAB","d":"bMEEIVUarP7FJFFjR2m6seB9kbH6mHeTLHmIeE5stsEHUGPmefCF0Cw3N9EqnJnzM8JA_Rv99s7XGEJi2qAiPiHR7OH_2eu8-qE2h0hVOBs1ZSg7nV87rZGHfK39GtKx-wbEGpvRHI3dqSqU3NXjvK6tLhNtnNrOpIyfRwGTd75f8JaDLT8Nu1jhWD5ZRWua2vq_yklFsDuPhePh2fjZJ6eAqIx5I3gkpwKue8IuWd5PTwvY3CAxH3qlWy9aGWDxohS-zL0mWOXg82ZM8jE1fEP0Y_kFmvR8lfQU7HYyi6HAxKatYCuWaRf83OT5F3R7_adCTW7eZjSvcpKuhOsBFQ","p":"2dB5m17kN-H90rN7dNZ9fMf_t7FFV4LxuYA2MvwZYC8SFGaMmdgx20LGNBk4Dx9Mz4p8t06CY_9T_0Djk9F8URlrQJl7tzVN8Qq-FHcnjAY9zV_DINfbAWoQ4ZQT3I8syiMbAK8uOK7ehXbR9IbaikfJjJc5d9Vg2BNAACGRL68","q":"y2rfZzUlw2JScRhq9T0BzVuFVSOpad13qKzVDs13EmFoGqk31Zt1EnvaiwU7uGu2bGTTgQK65V90C5zgZW07yz1A7tSVGEDWgi3rYznMQzCUUHXRx7cX9i5G8OnBDKFznVvGXKSOFVq3AP2cOkr3xmF25ST--xwG2Ca0lTNB-T8","dp":"oXQyICRHoNDIuB1IvwObAxqxB7XEg6jRi0JpaoOKP8zEZxDY2dTyp-eoScgD0NnPsuuhpLLyXjNOTSAJUXHv56Gi6cCbfuNpQepHmZ31V4rs1sZMOpUmhrbbioqb6lrKxY8eHfS8m1GsKlw4Jzyq0-OAl9EkzRoC7kfeofo_x4s","dq":"WuV9wJuiLUWxOzJDESTauk4MLXhLCrBY-PmKFxw--eqm30sAVSYrUUAg7wA-qHERSixfyoVSyI43x7ypFQmTr4TGkDJUEUtfzzn_tg4stVVu4OlU_V5Wib4yGxMJHcDDbeyFnf42M1qe7gVlmzLGt1H0E_7NJZ5nfI0HIqiN8Xc","qi":"umm077LJtRZ8A19-xYSglcc32423DS-nHz5s3sW-EdX3wKJCG1A4zFpQ3gNTcBtIlBMIcAQYspJVE78cxQCoZ5D9Br2BJBcTLOS0tOvRzSyacXtiBDK0d8dZ_8rf1pfDg1Gt-isN9JaUgPITeT7b-WaBBLFz_f5mWH0vK97fnAQ"}` + + jwk, err := svc.GenerateJWKS() + assert.NoError(t, err) + assert.Equal(t, expectedJWK, jwk) +} + +func TestGenerateJwksRS512Algorithm(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==" + svc.Options.Algorithm = JwtSigningAlgorithmRS512 + + expectedJWK := `{"kty":"RSA","kid":"ee8146d4b30a57d0053f39c80f4c3caa46461633","alg":"RS512","n":"rRM-1Pib0oeFCGOo5mNjsN8xOrL_n6It0MvuO3wRdpCewJfsV7IRpYnL5NLztWfuh6oklgvq20jM4451AzH3kndkx4XjcOtfh9ZIRo_qBXXxGF9XJWWzms21jZBs95Z_zwTZgpBU_mVhHRz5n9-UGSeEZYfu8ZyLxFFCroBV5RNkNrHC0IbNOVn8RuJzks89vyvS6DKlKcoHfjiUX0LMa5Fqq-nE-9Gy3QDsqdkaKpGBTnakseZB-hEvLPFX1T0TGd8sA48K12Z5YjIZ60ODziNpyvKn_n3YojZ1Lfz9C1Rl603qoenc1wcxYb-5XlXzjUd7JSFqfTOHtx_mcorzEQ","e":"AQAB","d":"bMEEIVUarP7FJFFjR2m6seB9kbH6mHeTLHmIeE5stsEHUGPmefCF0Cw3N9EqnJnzM8JA_Rv99s7XGEJi2qAiPiHR7OH_2eu8-qE2h0hVOBs1ZSg7nV87rZGHfK39GtKx-wbEGpvRHI3dqSqU3NXjvK6tLhNtnNrOpIyfRwGTd75f8JaDLT8Nu1jhWD5ZRWua2vq_yklFsDuPhePh2fjZJ6eAqIx5I3gkpwKue8IuWd5PTwvY3CAxH3qlWy9aGWDxohS-zL0mWOXg82ZM8jE1fEP0Y_kFmvR8lfQU7HYyi6HAxKatYCuWaRf83OT5F3R7_adCTW7eZjSvcpKuhOsBFQ","p":"2dB5m17kN-H90rN7dNZ9fMf_t7FFV4LxuYA2MvwZYC8SFGaMmdgx20LGNBk4Dx9Mz4p8t06CY_9T_0Djk9F8URlrQJl7tzVN8Qq-FHcnjAY9zV_DINfbAWoQ4ZQT3I8syiMbAK8uOK7ehXbR9IbaikfJjJc5d9Vg2BNAACGRL68","q":"y2rfZzUlw2JScRhq9T0BzVuFVSOpad13qKzVDs13EmFoGqk31Zt1EnvaiwU7uGu2bGTTgQK65V90C5zgZW07yz1A7tSVGEDWgi3rYznMQzCUUHXRx7cX9i5G8OnBDKFznVvGXKSOFVq3AP2cOkr3xmF25ST--xwG2Ca0lTNB-T8","dp":"oXQyICRHoNDIuB1IvwObAxqxB7XEg6jRi0JpaoOKP8zEZxDY2dTyp-eoScgD0NnPsuuhpLLyXjNOTSAJUXHv56Gi6cCbfuNpQepHmZ31V4rs1sZMOpUmhrbbioqb6lrKxY8eHfS8m1GsKlw4Jzyq0-OAl9EkzRoC7kfeofo_x4s","dq":"WuV9wJuiLUWxOzJDESTauk4MLXhLCrBY-PmKFxw--eqm30sAVSYrUUAg7wA-qHERSixfyoVSyI43x7ypFQmTr4TGkDJUEUtfzzn_tg4stVVu4OlU_V5Wib4yGxMJHcDDbeyFnf42M1qe7gVlmzLGt1H0E_7NJZ5nfI0HIqiN8Xc","qi":"umm077LJtRZ8A19-xYSglcc32423DS-nHz5s3sW-EdX3wKJCG1A4zFpQ3gNTcBtIlBMIcAQYspJVE78cxQCoZ5D9Br2BJBcTLOS0tOvRzSyacXtiBDK0d8dZ_8rf1pfDg1Gt-isN9JaUgPITeT7b-WaBBLFz_f5mWH0vK97fnAQ"}` + + jwk, err := svc.GenerateJWKS() + assert.NoError(t, err) + assert.Equal(t, expectedJWK, jwk) +} + +func TestGenerateJwksNoKeyAlgorithm(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==" + svc.Options.Algorithm = "" + + expectedJWK := `{"kty":"RSA","kid":"ee8146d4b30a57d0053f39c80f4c3caa46461633","alg":"RS256","n":"rRM-1Pib0oeFCGOo5mNjsN8xOrL_n6It0MvuO3wRdpCewJfsV7IRpYnL5NLztWfuh6oklgvq20jM4451AzH3kndkx4XjcOtfh9ZIRo_qBXXxGF9XJWWzms21jZBs95Z_zwTZgpBU_mVhHRz5n9-UGSeEZYfu8ZyLxFFCroBV5RNkNrHC0IbNOVn8RuJzks89vyvS6DKlKcoHfjiUX0LMa5Fqq-nE-9Gy3QDsqdkaKpGBTnakseZB-hEvLPFX1T0TGd8sA48K12Z5YjIZ60ODziNpyvKn_n3YojZ1Lfz9C1Rl603qoenc1wcxYb-5XlXzjUd7JSFqfTOHtx_mcorzEQ","e":"AQAB","d":"bMEEIVUarP7FJFFjR2m6seB9kbH6mHeTLHmIeE5stsEHUGPmefCF0Cw3N9EqnJnzM8JA_Rv99s7XGEJi2qAiPiHR7OH_2eu8-qE2h0hVOBs1ZSg7nV87rZGHfK39GtKx-wbEGpvRHI3dqSqU3NXjvK6tLhNtnNrOpIyfRwGTd75f8JaDLT8Nu1jhWD5ZRWua2vq_yklFsDuPhePh2fjZJ6eAqIx5I3gkpwKue8IuWd5PTwvY3CAxH3qlWy9aGWDxohS-zL0mWOXg82ZM8jE1fEP0Y_kFmvR8lfQU7HYyi6HAxKatYCuWaRf83OT5F3R7_adCTW7eZjSvcpKuhOsBFQ","p":"2dB5m17kN-H90rN7dNZ9fMf_t7FFV4LxuYA2MvwZYC8SFGaMmdgx20LGNBk4Dx9Mz4p8t06CY_9T_0Djk9F8URlrQJl7tzVN8Qq-FHcnjAY9zV_DINfbAWoQ4ZQT3I8syiMbAK8uOK7ehXbR9IbaikfJjJc5d9Vg2BNAACGRL68","q":"y2rfZzUlw2JScRhq9T0BzVuFVSOpad13qKzVDs13EmFoGqk31Zt1EnvaiwU7uGu2bGTTgQK65V90C5zgZW07yz1A7tSVGEDWgi3rYznMQzCUUHXRx7cX9i5G8OnBDKFznVvGXKSOFVq3AP2cOkr3xmF25ST--xwG2Ca0lTNB-T8","dp":"oXQyICRHoNDIuB1IvwObAxqxB7XEg6jRi0JpaoOKP8zEZxDY2dTyp-eoScgD0NnPsuuhpLLyXjNOTSAJUXHv56Gi6cCbfuNpQepHmZ31V4rs1sZMOpUmhrbbioqb6lrKxY8eHfS8m1GsKlw4Jzyq0-OAl9EkzRoC7kfeofo_x4s","dq":"WuV9wJuiLUWxOzJDESTauk4MLXhLCrBY-PmKFxw--eqm30sAVSYrUUAg7wA-qHERSixfyoVSyI43x7ypFQmTr4TGkDJUEUtfzzn_tg4stVVu4OlU_V5Wib4yGxMJHcDDbeyFnf42M1qe7gVlmzLGt1H0E_7NJZ5nfI0HIqiN8Xc","qi":"umm077LJtRZ8A19-xYSglcc32423DS-nHz5s3sW-EdX3wKJCG1A4zFpQ3gNTcBtIlBMIcAQYspJVE78cxQCoZ5D9Br2BJBcTLOS0tOvRzSyacXtiBDK0d8dZ_8rf1pfDg1Gt-isN9JaUgPITeT7b-WaBBLFz_f5mWH0vK97fnAQ"}` + + jwk, err := svc.GenerateJWKS() + assert.NoError(t, err) + assert.Equal(t, expectedJWK, jwk) +} + +func TestGenerateJWKSEmptyPrivateKey(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "" + + _, err := svc.GenerateJWKS() + assert.EqualError(t, err, "private key cannot be empty") +} + +func TestGenerateJWKSInvalidPrivateKey(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "invalidPrivateKey" + + _, err := svc.GenerateJWKS() + assert.Error(t, err) +} + +func TestGenerateJWKSRS384Algorithm(t *testing.T) { + svc := New(nil) + svc.Options.PrivateKey = "cGFzc3dvcmQ=" + svc.Options.Algorithm = JwtSigningAlgorithmRS384 + + _, err := svc.GenerateJWKS() + assert.Error(t, err, "private key cannot be empty") +} + +func TestVerifyJwtWithHS256Algorithm(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claims) + assert.NoError(t, err) + + token, err := svc.Parse(tokenStr) + assert.NoError(t, err) + + verifiedToken, err := token.Valid() + assert.NoError(t, err) + + assert.True(t, verifiedToken) +} + +func TestVerifyJwtWithHS256AlgorithmWithNoSecret(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + svc.WithSecret("secret") + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claims) + assert.NoError(t, err) + + svc.Options.Secret = "" + _, err = svc.Parse(tokenStr) + assert.Errorf(t, err, "secret cannot be empty") +} + +func TestVerifyJwtWithRS256Algorithm(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + svc.Options.Algorithm = JwtSigningAlgorithmRS256 + + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claims) + assert.NoError(t, err) + + token, err := svc.Parse(tokenStr) + assert.NoError(t, err) + + verifiedToken, err := token.Valid() + assert.NoError(t, err) + + assert.True(t, verifiedToken) +} + +func TestVerifyJwtWithRS256AlgorithmWithNoPrivateKey(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + svc.Options.Algorithm = JwtSigningAlgorithmRS256 + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + + claims := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claims) + assert.NoError(t, err) + + svc.Options.PrivateKey = "" + _, err = svc.Parse(tokenStr) + + assert.Errorf(t, err, "") +} + +func TestVerifyJJwtWithNoRolesAndClaims(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + svc.Options.Algorithm = JwtSigningAlgorithmRS256 + svc.WithPrivateKey("LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBclJNKzFQaWIwb2VGQ0dPbzVtTmpzTjh4T3JML242SXQwTXZ1TzN3UmRwQ2V3SmZzClY3SVJwWW5MNU5MenRXZnVoNm9rbGd2cTIwak00NDUxQXpIM2tuZGt4NFhqY090Zmg5WklSby9xQlhYeEdGOVgKSldXem1zMjFqWkJzOTVaL3p3VFpncEJVL21WaEhSejVuOStVR1NlRVpZZnU4WnlMeEZGQ3JvQlY1Uk5rTnJIQwowSWJOT1ZuOFJ1Snprczg5dnl2UzZES2xLY29IZmppVVgwTE1hNUZxcStuRSs5R3kzUURzcWRrYUtwR0JUbmFrCnNlWkIraEV2TFBGWDFUMFRHZDhzQTQ4SzEyWjVZaklaNjBPRHppTnB5dktuL24zWW9qWjFMZno5QzFSbDYwM3EKb2VuYzF3Y3hZYis1WGxYempVZDdKU0ZxZlRPSHR4L21jb3J6RVFJREFRQUJBb0lCQUd6QkJDRlZHcXoreFNSUgpZMGRwdXJIZ2ZaR3grcGgza3l4NWlIaE9iTGJCQjFCajVubndoZEFzTnpmUktweVo4elBDUVAwYi9mYk8xeGhDCll0cWdJajRoMGV6aC85bnJ2UHFoTm9kSVZUZ2JOV1VvTzUxZk82MlJoM3l0L1JyU3Nmc0d4QnFiMFJ5TjNha3EKbE56VjQ3eXVyUzRUYlp6YXpxU01uMGNCazNlK1gvQ1dneTAvRGJ0WTRWZytXVVZybXRyNnY4cEpSYkE3ajRYago0ZG40MlNlbmdLaU1lU040SktjQ3JudkNMbG5lVDA4TDJOd2dNUjk2cFZzdldobGc4YUlVdnN5OUpsamw0UE5tClRQSXhOWHhEOUdQNUJacjBmSlgwRk94Mk1vdWh3TVNtcldBcmxta1gvTnprK1JkMGUvMm5RazF1M21ZMHIzS1MKcm9UckFSVUNnWUVBMmRCNW0xN2tOK0g5MHJON2ROWjlmTWYvdDdGRlY0THh1WUEyTXZ3WllDOFNGR2FNbWRneAoyMExHTkJrNER4OU16NHA4dDA2Q1kvOVQvMERqazlGOFVSbHJRSmw3dHpWTjhRcStGSGNuakFZOXpWL0RJTmZiCkFXb1E0WlFUM0k4c3lpTWJBSzh1T0s3ZWhYYlI5SWJhaWtmSmpKYzVkOVZnMkJOQUFDR1JMNjhDZ1lFQXkycmYKWnpVbHcySlNjUmhxOVQwQnpWdUZWU09wYWQxM3FLelZEczEzRW1Gb0dxazMxWnQxRW52YWl3VTd1R3UyYkdUVApnUUs2NVY5MEM1emdaVzA3eXoxQTd0U1ZHRURXZ2kzcll6bk1RekNVVUhYUng3Y1g5aTVHOE9uQkRLRnpuVnZHClhLU09GVnEzQVAyY09rcjN4bUYyNVNUKyt4d0cyQ2EwbFROQitUOENnWUVBb1hReUlDUkhvTkRJdUIxSXZ3T2IKQXhxeEI3WEVnNmpSaTBKcGFvT0tQOHpFWnhEWTJkVHlwK2VvU2NnRDBOblBzdXVocExMeVhqTk9UU0FKVVhIdgo1NkdpNmNDYmZ1TnBRZXBIbVozMVY0cnMxc1pNT3BVbWhyYmJpb3FiNmxyS3hZOGVIZlM4bTFHc0tsdzRKenlxCjArT0FsOUVrelJvQzdrZmVvZm8veDRzQ2dZQmE1WDNBbTZJdFJiRTdNa01SSk5xNlRnd3RlRXNLc0ZqNCtZb1gKSEQ3NTZxYmZTd0JWSml0UlFDRHZBRDZvY1JGS0xGL0toVkxJampmSHZLa1ZDWk92aE1hUU1sUVJTMS9QT2YrMgpEaXkxVlc3ZzZWVDlYbGFKdmpJYkV3a2R3TU50N0lXZC9qWXpXcDd1QldXYk1zYTNVZlFUL3MwbG5tZDhqUWNpCnFJM3hkd0tCZ1FDNmFiVHZzc20xRm53RFgzN0ZoS0NWeHpmYmpiY05MNmNmUG16ZXhiNFIxZmZBb2tJYlVEak0KV2xEZUExTndHMGlVRXdod0JCaXlrbFVUdnh6RkFLaG5rUDBHdllFa0Z4TXM1TFMwNjlITkxKcHhlMklFTXJSMwp4MW4veXQvV2w4T0RVYTM2S3czMGxwU0E4aE41UHR2NVpvRUVzWFA5L21aWWZTOHIzdCtjQkE9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + + claims := map[string]interface{}{ + "email": "test@example.com", + } + + tokenStr, err := svc.Sign(claims) + assert.NoError(t, err) + + _, err = svc.Parse(tokenStr) + assert.NoError(t, err) +} + +func TestJwtService_processEnvironmentVariables(t *testing.T) { + ctx := basecontext.NewBaseContext() + svc := New(ctx) + + t.Run("SetAlgorithm", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_SIGN_ALGORITHM_ENV_VAR, "HS256") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, JwtSigningAlgorithmHS256, svc.Options.Algorithm) + }) + + t.Run("SetInvalidAlgorithm", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_SIGN_ALGORITHM_ENV_VAR, "invalid") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Error(t, err) + assert.Equal(t, errors.New("invalid signing algorithm"), err) + }) + + t.Run("SetSecret", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_HMACS_SECRET_ENV_VAR, "secret") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, "secret", svc.Options.Secret) + }) + + t.Run("SetPrivateKey", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_PRIVATE_KEY_ENV_VAR, "private_key") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, "private_key", svc.Options.PrivateKey) + }) + + t.Run("SetTokenDuration", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_DURATION_ENV_VAR, "60") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, time.Duration(60)*time.Minute, svc.Options.TokenDuration) + }) + + t.Run("InvalidTokenDuration", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.JWT_DURATION_ENV_VAR, "invalid") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Error(t, err) + }) +} diff --git a/src/security/jwt/options.go b/src/security/jwt/options.go new file mode 100644 index 00000000..d2081d9c --- /dev/null +++ b/src/security/jwt/options.go @@ -0,0 +1,58 @@ +package jwt + +import ( + "time" + + "github.com/Parallels/pd-api-service/basecontext" +) + +type JwtSigningAlgorithm string + +const ( + JwtSigningAlgorithmHS256 JwtSigningAlgorithm = "HS256" + JwtSigningAlgorithmHS384 JwtSigningAlgorithm = "HS384" + JwtSigningAlgorithmHS512 JwtSigningAlgorithm = "HS512" + JwtSigningAlgorithmRS256 JwtSigningAlgorithm = "RS256" + JwtSigningAlgorithmRS384 JwtSigningAlgorithm = "RS384" + JwtSigningAlgorithmRS512 JwtSigningAlgorithm = "RS512" +) + +type JwtOptions struct { + ctx basecontext.ApiContext + Algorithm JwtSigningAlgorithm + Secret string + PrivateKey string + TokenDuration time.Duration +} + +func NewDefaultOptions(ctx basecontext.ApiContext) *JwtOptions { + if ctx == nil { + ctx = basecontext.NewRootBaseContext() + } + + return &JwtOptions{ + ctx: ctx, + Algorithm: JwtSigningAlgorithmHS256, + TokenDuration: time.Duration(20) * time.Minute, + } +} + +func (o *JwtOptions) WithAlgorithm(algorithm JwtSigningAlgorithm) *JwtOptions { + o.Algorithm = algorithm + return o +} + +func (o *JwtOptions) WithSecret(secret string) *JwtOptions { + o.Secret = secret + return o +} + +func (o *JwtOptions) WithPrivateKey(privateKey string) *JwtOptions { + o.PrivateKey = privateKey + return o +} + +func (o *JwtOptions) WithTokenDuration(durationInMinutes float64) *JwtOptions { + o.TokenDuration = time.Duration(durationInMinutes) * time.Minute + return o +} diff --git a/src/security/jwt/options_test.go b/src/security/jwt/options_test.go new file mode 100644 index 00000000..eb87217e --- /dev/null +++ b/src/security/jwt/options_test.go @@ -0,0 +1,52 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/Parallels/pd-api-service/basecontext" +) + +func TestNewDefaultOptions(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + options := NewDefaultOptions(ctx) + if options == nil { + t.Errorf("NewDefaultOptions returned nil") + } +} + +func TestWithAlgorithm(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + options := NewDefaultOptions(ctx) + options.WithAlgorithm(JwtSigningAlgorithmHS256) + if options.Algorithm != JwtSigningAlgorithmHS256 { + t.Errorf("WithAlgorithm did not set Algorithm") + } +} + +func TestWithSecret(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + options := NewDefaultOptions(ctx) + options.WithSecret("secret") + if options.Secret != "secret" { + t.Errorf("WithSecret did not set Secret") + } +} + +func TestWithPrivateKey(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + options := NewDefaultOptions(ctx) + options.WithPrivateKey("privateKey") + if options.PrivateKey != "privateKey" { + t.Errorf("WithPrivateKey did not set PrivateKey") + } +} + +func TestWithTokenDuration(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + options := NewDefaultOptions(ctx) + options.WithTokenDuration(20) + if options.TokenDuration != time.Duration(20)*time.Minute { + t.Errorf("WithTokenDuration did not set TokenDuration") + } +} diff --git a/src/security/jwt/token.go b/src/security/jwt/token.go new file mode 100644 index 00000000..a4536518 --- /dev/null +++ b/src/security/jwt/token.go @@ -0,0 +1,89 @@ +package jwt + +import ( + "time" + + "github.com/Parallels/pd-api-service/errors" + "github.com/golang-jwt/jwt/v4" +) + +type JwtSystemToken struct { + token string + tokenObj *jwt.Token + Claims map[string]interface{} +} + +func (s *JwtSystemToken) Valid() (bool, error) { + if s.tokenObj == nil { + return false, errors.New("tokenObj is nil") + } + + if _, ok := s.tokenObj.Claims.(jwt.MapClaims); !ok { + return false, errors.New("invalid claims") + } + + if err := s.tokenObj.Claims.Valid(); err != nil { + return false, err + } + + return s.tokenObj.Valid, nil +} + +func (s *JwtSystemToken) GetTokenClaims() (map[string]interface{}, error) { + claims, ok := s.tokenObj.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("invalid claims") + } + + s.Claims = claims + return claims, nil +} + +func (s *JwtSystemToken) GetEmail() (string, error) { + if s.Claims == nil { + _, err := s.GetTokenClaims() + if err != nil { + return "", err + } + } + + email, ok := s.Claims["email"].(string) + if !ok { + return "", errors.New("invalid email") + } + + return email, nil +} + +func (s *JwtSystemToken) GetExpiresAt() (time.Time, error) { + if s.Claims == nil { + _, err := s.GetTokenClaims() + if err != nil { + return time.Time{}, err + } + } + + expiresAt, ok := s.Claims["exp"].(float64) + if !ok { + return time.Time{}, errors.New("invalid expiresAt") + } + + parsedTime := time.Unix(int64(expiresAt), 0) + return parsedTime, nil +} + +func (s *JwtSystemToken) GetClaim(key string) (interface{}, error) { + if s.Claims == nil { + _, err := s.GetTokenClaims() + if err != nil { + return nil, err + } + } + + claim, ok := s.Claims[key] + if !ok { + return nil, errors.New("invalid claim") + } + + return claim, nil +} diff --git a/src/security/jwt/token_test.go b/src/security/jwt/token_test.go new file mode 100644 index 00000000..3381d998 --- /dev/null +++ b/src/security/jwt/token_test.go @@ -0,0 +1,188 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" +) + +func SetupToken(t *testing.T) *JwtSystemToken { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.Options.WithSecret("secret") + + // Test case 1: Sign with valid input + claimsMap := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claimsMap) + assert.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + token, err := svc.Parse(tokenStr) + assert.NoError(t, err) + + return token +} + +func TestValid(t *testing.T) { + token := SetupToken(t) + + valid, err := token.Valid() + assert.NoError(t, err) + assert.True(t, valid) +} + +func TestValidWithError(t *testing.T) { + token := SetupToken(t) + token.tokenObj = nil + + valid, err := token.Valid() + assert.Errorf(t, err, "error: tokenObj is nil") + assert.False(t, valid) +} + +func TestValidExpired(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.Options.WithSecret("secret") + svc.Options.WithTokenDuration(0.1) + + // Test case 1: Sign with valid input + claimsMap := map[string]interface{}{ + "email": "test@example.com", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claimsMap) + assert.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + token, err := svc.Parse(tokenStr) + assert.NoError(t, err) + time.Sleep(10 * time.Second) + + valid, err := token.Valid() + assert.Errorf(t, err, "Token is expired") + assert.False(t, valid) +} + +func TestGetClaims(t *testing.T) { + token := SetupToken(t) + + claims, err := token.GetTokenClaims() + assert.NoError(t, err) + assert.Equal(t, "test@example.com", claims["email"]) +} + +func TestGetEmail(t *testing.T) { + token := SetupToken(t) + + email, err := token.GetEmail() + assert.NoError(t, err) + assert.Equal(t, "test@example.com", email) +} + +func TestGetEmailNoClaim(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = nil + + email, err := token.GetEmail() + assert.Errorf(t, err, "invalid claims") + assert.Equal(t, "", email) +} + +func TestGetEmailWrongFormat(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = jwt.MapClaims{ + "email": 2, + } + + email, err := token.GetEmail() + assert.Errorf(t, err, "invalid email") + assert.Equal(t, "", email) +} + +func TestGetExpiresAt(t *testing.T) { + token := SetupToken(t) + + expiresAt, err := token.GetExpiresAt() + assert.NoError(t, err) + assert.True(t, time.Now().Before(expiresAt)) +} + +func TestGetExpiresAtNoClaim(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = nil + + expiresAt, err := token.GetExpiresAt() + assert.Errorf(t, err, "invalid claims") + assert.False(t, time.Now().Before(expiresAt)) +} + +func TestGetExpiresAtWrongFormat(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = jwt.MapClaims{ + "exp": "wrong", + } + + expiresAt, err := token.GetExpiresAt() + assert.Errorf(t, err, "invalid expiresAt") + assert.False(t, time.Now().Before(expiresAt)) +} + +func TestGetClaim(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.Options.WithSecret("secret") + + // Test case 1: Sign with valid input + claimsMap := map[string]interface{}{ + "email": "test@example.com", + "uid": "1234567890", + "roles": []string{"admin", "user"}, + "claims": []string{"claim1", "claim2"}, + } + + tokenStr, err := svc.Sign(claimsMap) + assert.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + token, err := svc.Parse(tokenStr) + assert.NoError(t, err) + + claim, err := token.GetClaim("uid") + assert.NoError(t, err) + assert.Equal(t, "1234567890", claim) +} + +func TestGetClaimNoClaim(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = nil + + _, err := token.GetClaim("uid") + assert.Errorf(t, err, "invalid claims") +} + +func TestGetClaimWrongFormat(t *testing.T) { + token := SetupToken(t) + token.Claims = nil + token.tokenObj.Claims = jwt.MapClaims{ + "exp": "wrong", + } + + _, err := token.GetClaim("test") + assert.Errorf(t, err, "invalid claim") +} diff --git a/src/security/main.go b/src/security/main.go index 0b0e80e2..153b58d0 100644 --- a/src/security/main.go +++ b/src/security/main.go @@ -5,20 +5,41 @@ import ( "crypto/cipher" "crypto/rand" "crypto/rsa" + "crypto/sha1" "crypto/x509" + "encoding/base64" + "encoding/hex" "encoding/pem" "io" "os" "github.com/Parallels/pd-api-service/errors" + cryptorand "github.com/cjlapao/common-go-cryptorand" ) -func GenPrivateRsaKey(filename string) error { +func GenerateCryptoRandomString(length int) (string, error) { + if length <= 0 { + return "", errors.New("length must be greater than 0") + } + + result, err := cryptorand.GetAlphaNumericRandomString(length) + if err != nil { + return "", err + } + + return result, nil +} + +func GenPrivateRsaKey(filename string, size int) error { if filename == "" { return errors.New("filename is empty") } - privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if size <= 0 { + size = 2048 + } + + privKey, err := rsa.GenerateKey(rand.Reader, size) if err != nil { return err } @@ -109,3 +130,21 @@ func DecryptString(privateKey string, cipherText []byte) (string, error) { return string(plaintext), nil } + +func Base64Decode(input string) ([]byte, error) { + decoded, err := base64.StdEncoding.DecodeString(input) + if err != nil { + return nil, err + } + return decoded, nil +} + +func Base64Encode(input []byte) string { + return base64.StdEncoding.EncodeToString(input) +} + +func CalculatePrivateKeyThumbprint(privateKey *rsa.PrivateKey) (string, error) { + publicKeyDer := x509.MarshalPKCS1PublicKey(&privateKey.PublicKey) + hash := sha1.Sum(publicKeyDer) + return hex.EncodeToString(hash[:]), nil +} diff --git a/src/security/password/main.go b/src/security/password/main.go new file mode 100644 index 00000000..e43d31c4 --- /dev/null +++ b/src/security/password/main.go @@ -0,0 +1,249 @@ +package password + +import ( + "crypto/sha256" + "encoding/hex" + "strconv" + "strings" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/Parallels/pd-api-service/config" + "github.com/Parallels/pd-api-service/constants" + "github.com/Parallels/pd-api-service/errors" + "golang.org/x/crypto/bcrypt" +) + +var globalPasswordService *PasswordService + +const ( + SPECIAL_CHARACTERS = "!@#$%^&*()_+.?" +) + +type PasswordHashingAlgorithm string + +const ( + PasswordHashingAlgorithmBCrypt PasswordHashingAlgorithm = "bcrypt" + PasswordHashingAlgorithmSHA256 PasswordHashingAlgorithm = "sha256" +) + +type PasswordService struct { + ctx basecontext.ApiContext + HashingAlgorithm PasswordHashingAlgorithm + options *PasswordComplexityOptions +} + +func Get() *PasswordService { + if globalPasswordService == nil { + ctx := basecontext.NewRootBaseContext() + return New(ctx) + } + + return globalPasswordService +} + +func New(ctx basecontext.ApiContext) *PasswordService { + globalPasswordService = &PasswordService{ + ctx: ctx, + HashingAlgorithm: PasswordHashingAlgorithmBCrypt, + options: NewPasswordComplexityOptions(ctx), + } + + err := globalPasswordService.processEnvironmentVariables() + if err != nil { + ctx.LogError("Error processing environment variables for password complexity options: %s", err.Error()) + } + + return globalPasswordService +} + +func (s *PasswordService) GetOptions() *PasswordComplexityOptions { + return s.options +} + +func (s *PasswordService) SetOptions(options *PasswordComplexityOptions) { + s.options = options +} + +func (s *PasswordService) Hash(password string, salt string) (string, error) { + switch s.HashingAlgorithm { + case PasswordHashingAlgorithmBCrypt: + return s.hashBCrypt(password, salt) + case PasswordHashingAlgorithmSHA256: + return s.hashSHA256(password, salt) + default: + s.ctx.LogError("Unknown password hashing algorithm: %s", s.HashingAlgorithm) + return "", errors.Newf("Unknown password hashing algorithm: %s", s.HashingAlgorithm) + } +} + +func (s *PasswordService) Compare(password string, salt string, hashedPwd string) error { + switch s.HashingAlgorithm { + case PasswordHashingAlgorithmBCrypt: + return s.compareBCrypt(password, salt, hashedPwd) + case PasswordHashingAlgorithmSHA256: + return s.sha256Compare(password, salt, hashedPwd) + default: + s.ctx.LogError("Unknown password hashing algorithm: %s", s.HashingAlgorithm) + return errors.Newf("Unknown password hashing algorithm: %s", s.HashingAlgorithm) + } +} + +func (s *PasswordService) CheckPasswordComplexity(password string) (bool, *errors.Diagnostics) { + diagnostics := errors.NewDiagnostics() + + if len(password) < s.options.MinLength() { + diagnostics.AddError(errors.Newf("Password must be at least %d characters long", s.options.MinLength())) + } + if len(password) > s.options.MaxLength() { + diagnostics.AddError(errors.Newf("Password must be no more than %d characters long", s.options.MaxLength())) + } + if s.options.RequireLowercase() { + if !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") { + diagnostics.AddError(errors.Newf("Password must contain at least one lowercase letter")) + } + } + if s.options.RequireUppercase() { + if !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") { + diagnostics.AddError(errors.Newf("Password must contain at least one uppercase letter")) + } + } + if s.options.RequireNumbers() { + if !strings.ContainsAny(password, "0123456789") { + diagnostics.AddError(errors.Newf("Password must contain at least one number")) + } + } + if s.options.RequireSpecialCharacters() { + if !strings.ContainsAny(password, SPECIAL_CHARACTERS) { + diagnostics.AddError(errors.Newf("Password must contain at least one special character")) + } + } + + return !diagnostics.HasErrors(), diagnostics +} + +func (s *PasswordService) hashSHA256(password string, salt string) (string, error) { + saltedPwd, err := s.saltPassword(password, salt) + if err != nil { + return "", err + } + + hashedPassword := sha256.Sum256([]byte(saltedPwd)) + return hex.EncodeToString(hashedPassword[:]), nil +} + +func (s PasswordService) sha256Compare(password string, salt string, hashedPwd string) error { + saltedPwd, err := s.saltPassword(password, salt) + if err != nil { + return err + } + + hashedPassword := sha256.Sum256([]byte(saltedPwd)) + hashedPasswordString := hex.EncodeToString(hashedPassword[:]) + if hashedPasswordString != hashedPwd { + return errors.New("passwords do not match") + } + + return nil +} + +func (s *PasswordService) hashBCrypt(password string, salt string) (string, error) { + cost := bcrypt.DefaultCost + saltedPwd, err := s.saltPassword(password, salt) + if err != nil { + return "", err + } + + bytes, err := bcrypt.GenerateFromPassword([]byte(saltedPwd), cost) + if err != nil { + return "", err + } + return string(bytes), nil +} + +func (s *PasswordService) compareBCrypt(password string, salt string, hashedPwd string) error { + saltedPwd, err := s.saltPassword(password, salt) + if err != nil { + return err + } + + err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), saltedPwd) + if err != nil { + return err + } + return nil +} + +func (s *PasswordService) saltPassword(password string, salt string) ([]byte, error) { + // saltString := GenerateSalt(salt, cost) + inputBytes := []byte(password) + saltBytes := []byte(salt) + if len(inputBytes) > 40 { + return []byte{}, errors.New("password cannot be longer than 40 characters") + } + if len(saltBytes) > 32 { + saltBytes = saltBytes[:32] + } + + if !s.options.SaltPassword() { + return inputBytes, nil + } + + saltedPwd := []byte(password + string(saltBytes)) + + return saltedPwd, nil +} + +func (s *PasswordService) processEnvironmentVariables() error { + cfg := config.NewConfig() + if cfg.GetKey(constants.SECURITY_PASSWORD_MIN_PASSWORD_LENGTH_ENV_VAR) != "" { + minPasswordLength, err := strconv.Atoi(cfg.GetKey(constants.SECURITY_PASSWORD_MIN_PASSWORD_LENGTH_ENV_VAR)) + if err != nil { + return err + } + s.options.WithMinLength(minPasswordLength) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_MAX_PASSWORD_LENGTH_ENV_VAR) != "" { + maxPasswordLength, err := strconv.Atoi(cfg.GetKey(constants.SECURITY_PASSWORD_MAX_PASSWORD_LENGTH_ENV_VAR)) + if err != nil { + return err + } + s.options.WithMaxLength(maxPasswordLength) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_LOWERCASE_ENV_VAR) != "" { + requireLowercase, err := strconv.ParseBool(cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_LOWERCASE_ENV_VAR)) + if err != nil { + return err + } + s.options.WithRequireLowercase(requireLowercase) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_UPPERCASE_ENV_VAR) != "" { + requireUppercase, err := strconv.ParseBool(cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_UPPERCASE_ENV_VAR)) + if err != nil { + return err + } + s.options.WithRequireUppercase(requireUppercase) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_NUMBER_ENV_VAR) != "" { + requireNumber, err := strconv.ParseBool(cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_NUMBER_ENV_VAR)) + if err != nil { + return err + } + s.options.WithRequireNumbers(requireNumber) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR_ENV_VAR) != "" { + requireSpecialChar, err := strconv.ParseBool(cfg.GetKey(constants.SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR_ENV_VAR)) + if err != nil { + return err + } + s.options.WithRequireSpecialCharacters(requireSpecialChar) + } + if cfg.GetKey(constants.SECURITY_PASSWORD_SALT_PASSWORD_ENV_VAR) != "" { + saltPassword, err := strconv.ParseBool(cfg.GetKey(constants.SECURITY_PASSWORD_SALT_PASSWORD_ENV_VAR)) + if err != nil { + return err + } + s.options.WithSaltPassword(saltPassword) + } + + return nil +} diff --git a/src/security/password/main_test.go b/src/security/password/main_test.go new file mode 100644 index 00000000..9d2beecb --- /dev/null +++ b/src/security/password/main_test.go @@ -0,0 +1,520 @@ +package password + +import ( + "crypto/sha256" + "encoding/hex" + "os" + "testing" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/Parallels/pd-api-service/constants" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +func TestNoHashingAlgorithm(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = "test" + + _, err := svc.Hash(input, salt) + if err != nil { + assert.EqualError(t, err, "error: Unknown password hashing algorithm: test") + } +} + +func TestGetWithoutNew(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + globalPasswordService = nil + + svc := Get() + svc.SetOptions(&PasswordComplexityOptions{ + ctx: ctx, + minLength: 20, + maxLength: 40, + requireLowercase: true, + requireUppercase: true, + requireNumbers: true, + requireSpecialCharacters: true, + saltPassword: false, + }) + + testSvc := Get() + + assert.Equal(t, 20, svc.GetOptions().MinLength()) + assert.Equal(t, 40, svc.GetOptions().MaxLength()) + assert.Equal(t, true, svc.GetOptions().RequireLowercase()) + assert.Equal(t, true, svc.GetOptions().RequireUppercase()) + assert.Equal(t, true, svc.GetOptions().RequireNumbers()) + assert.Equal(t, true, svc.GetOptions().RequireSpecialCharacters()) + + assert.Equal(t, testSvc.GetOptions().MinLength(), svc.GetOptions().MinLength()) + assert.Equal(t, testSvc.GetOptions().MaxLength(), svc.GetOptions().MaxLength()) + assert.Equal(t, testSvc.GetOptions().RequireLowercase(), svc.GetOptions().RequireLowercase()) + assert.Equal(t, testSvc.GetOptions().RequireUppercase(), svc.GetOptions().RequireUppercase()) + assert.Equal(t, testSvc.GetOptions().RequireNumbers(), svc.GetOptions().RequireNumbers()) + assert.Equal(t, testSvc.GetOptions().RequireSpecialCharacters(), svc.GetOptions().RequireSpecialCharacters()) + assert.Equal(t, testSvc.GetOptions().SaltPassword(), svc.GetOptions().SaltPassword()) +} + +func TestSHA256hash(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmSHA256 + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + CompareHashedPwd := sha256.Sum256([]byte(input + salt)) + hashedPwdStr := hex.EncodeToString(CompareHashedPwd[:]) + + assert.Equal(t, hashedPwd, hashedPwdStr) +} + +func TestBcryptHash(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(input+salt)) + if err != nil { + t.Errorf("Hashed password does not match input: %v", err) + } +} + +func TestBcryptHashWithNoSalt(t *testing.T) { + input := "password" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + svc.GetOptions().WithSaltPassword(false) + + hashedPwd, err := svc.Hash(input, "") + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(input)) + if err != nil { + t.Errorf("Hashed password does not match input: %v", err) + } +} + +func TestSaltPasswordBiggerThan40Characters(t *testing.T) { + input := "password" + salt := "somesaltthatisbiggerthan40characters" + smallerSalt := "somesaltthatisbiggerthan40charac" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(input+smallerSalt)) + if err != nil { + t.Errorf("Hashed password does not match input: %v", err) + } +} + +func TestPasswordLongerThan40Characters(t *testing.T) { + input := "passwordpasswordpasswordpasswordpasswordpassword" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + + _, err := svc.Hash(input, salt) + if err != nil { + assert.EqualError(t, err, "error: password cannot be longer than 40 characters") + } +} + +func TestSetOptionsWithValuesGreaterThan40(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + + svc.GetOptions().WithMinLength(4) + svc.GetOptions().WithMaxLength(50) + + assert.Equal(t, 8, svc.GetOptions().MinLength()) + assert.Equal(t, 40, svc.GetOptions().MaxLength()) +} + +func TestCheckPasswordComplexity(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.GetOptions().WithMinLength(8) + svc.GetOptions().WithMaxLength(40) + svc.GetOptions().WithRequireLowercase(true) + svc.GetOptions().WithRequireUppercase(true) + svc.GetOptions().WithRequireNumbers(true) + svc.GetOptions().WithRequireSpecialCharacters(true) + + tests := []struct { + name string + password string + expected bool + }{ + { + name: "Valid password", + password: "Password123!", + expected: true, + }, + { + name: "Password too short", + password: "pass", + expected: false, + }, + { + name: "Password too long", + password: "passwordpasswordpasswordpasswordpasswordpasswordpasswordpassword", + expected: false, + }, + { + name: "Missing lowercase letter", + password: "PASSWORD123!", + expected: false, + }, + { + name: "Missing uppercase letter", + password: "password123!", + expected: false, + }, + { + name: "Missing number", + password: "Password!", + expected: false, + }, + { + name: "Missing special character", + password: "Password123", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + valid, diagnostics := svc.CheckPasswordComplexity(test.password) + assert.Equal(t, test.expected, valid) + if test.expected { + assert.Empty(t, diagnostics.Errors()) + } else { + assert.NotEmpty(t, diagnostics.Errors()) + } + }) + } +} + +func TestCompareBcrypt(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmBCrypt + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = svc.Compare(input, salt, hashedPwd) + if err != nil { + t.Errorf("Hashed password does not match input: %v", err) + } +} + +func TestCompareSHA256(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmSHA256 + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = svc.Compare(input, salt, hashedPwd) + if err != nil { + t.Errorf("Hashed password does not match input: %v", err) + } +} + +func TestCompareSHA256WithWrongPassword(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmSHA256 + + hashedPwd, err := svc.Hash(input, salt) + if err != nil { + t.Errorf("Error hashing password: %v", err) + } + + err = svc.Compare("wrongpassword", salt, hashedPwd) + assert.EqualError(t, err, "error: passwords do not match") +} + +func TestCompareWithNoHash(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = "test" + + err := svc.Compare(input, salt, "") + assert.EqualError(t, err, "error: Unknown password hashing algorithm: test") +} + +func TestHashWithBcryptSaltError(t *testing.T) { + input := "passwordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpassword" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + + _, err := svc.Hash(input, salt) + assert.EqualError(t, err, "error: password cannot be longer than 40 characters") +} + +func TestHashWithSHA256SaltError(t *testing.T) { + input := "passwordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpassword" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmSHA256 + + _, err := svc.Hash(input, salt) + assert.EqualError(t, err, "error: password cannot be longer than 40 characters") +} + +func TestCompareWithBcryptSaltError(t *testing.T) { + input := "password" + LongInput := "passwordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpassword" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + + _, err := svc.Hash(input, salt) + assert.Nil(t, err) + err = svc.Compare(LongInput, salt, "") + assert.EqualError(t, err, "error: password cannot be longer than 40 characters") +} + +func TestCompareWithBcryptError(t *testing.T) { + input := "password" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + + _, err := svc.Hash(input, salt) + assert.Nil(t, err) + err = svc.Compare("", salt, "") + assert.EqualError(t, err, "crypto/bcrypt: hashedSecret too short to be a bcrypted password") +} + +func TestCompareWithSHA256SaltError(t *testing.T) { + input := "password" + LongInput := "passwordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpasswordpassword" + salt := "somesalt" + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.HashingAlgorithm = PasswordHashingAlgorithmSHA256 + + _, err := svc.Hash(input, salt) + assert.Nil(t, err) + err = svc.Compare(LongInput, salt, "") + assert.EqualError(t, err, "error: password cannot be longer than 40 characters") +} +func TestPasswordService_processEnvironmentVariables(t *testing.T) { + svc := New(nil) + + t.Run("SetMinPasswordLength", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_MIN_PASSWORD_LENGTH_ENV_VAR, "8") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, 8, svc.GetOptions().MinLength()) + }) + + t.Run("SetMaxPasswordLength", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_MAX_PASSWORD_LENGTH_ENV_VAR, "40") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, 40, svc.GetOptions().MaxLength()) + }) + + t.Run("SetRequireLowercase", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_LOWERCASE_ENV_VAR, "true") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, true, svc.GetOptions().RequireLowercase()) + }) + + t.Run("SetRequireUppercase", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_UPPERCASE_ENV_VAR, "true") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, true, svc.GetOptions().RequireUppercase()) + }) + + t.Run("SetRequireNumber", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_NUMBER_ENV_VAR, "true") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, true, svc.GetOptions().RequireNumbers()) + }) + + t.Run("SetRequireSpecialChar", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR_ENV_VAR, "true") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, true, svc.GetOptions().RequireSpecialCharacters()) + }) + + t.Run("SetSaltPassword", func(t *testing.T) { + os.Clearenv() + err := os.Setenv(constants.SECURITY_PASSWORD_SALT_PASSWORD_ENV_VAR, "true") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.NoError(t, err) + + assert.Equal(t, true, svc.GetOptions().SaltPassword()) + }) +} + +func TestPasswordService_processEnvironmentVariablesError(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + t.Run("SetMinPasswordLengthError", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_MIN_PASSWORD_LENGTH_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, 12, svc.GetOptions().MinLength()) + }) + + t.Run("SetMaxPasswordLength", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_MAX_PASSWORD_LENGTH_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, 40, svc.GetOptions().MaxLength()) + }) + + t.Run("SetRequireLowercase", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_LOWERCASE_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, true, svc.GetOptions().RequireLowercase()) + }) + + t.Run("SetRequireUppercase", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_UPPERCASE_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, true, svc.GetOptions().RequireUppercase()) + }) + + t.Run("SetRequireNumber", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_NUMBER_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, true, svc.GetOptions().RequireNumbers()) + }) + + t.Run("SetRequireSpecialChar", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_REQUIRE_SPECIAL_CHAR_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, true, svc.GetOptions().RequireSpecialCharacters()) + }) + + t.Run("SetSaltPassword", func(t *testing.T) { + os.Clearenv() + svc := New(ctx) + + err := os.Setenv(constants.SECURITY_PASSWORD_SALT_PASSWORD_ENV_VAR, "A") + assert.NoError(t, err) + + err = svc.processEnvironmentVariables() + assert.Errorf(t, err, "strconv.Atoi: parsing \"A\": invalid syntax") + + assert.Equal(t, true, svc.GetOptions().SaltPassword()) + }) +} diff --git a/src/security/password/options.go b/src/security/password/options.go new file mode 100644 index 00000000..5385d83f --- /dev/null +++ b/src/security/password/options.go @@ -0,0 +1,99 @@ +package password + +import "github.com/Parallels/pd-api-service/basecontext" + +type PasswordComplexityOptions struct { + ctx basecontext.ApiContext + minLength int + maxLength int + requireLowercase bool + requireUppercase bool + requireNumbers bool + requireSpecialCharacters bool + saltPassword bool +} + +func NewPasswordComplexityOptions(ctx basecontext.ApiContext) *PasswordComplexityOptions { + return &PasswordComplexityOptions{ + ctx: ctx, + minLength: 12, + maxLength: 40, + requireLowercase: true, + requireUppercase: true, + requireNumbers: true, + requireSpecialCharacters: true, + saltPassword: true, + } +} + +func (p *PasswordComplexityOptions) WithMinLength(minLength int) *PasswordComplexityOptions { + if minLength < 8 { + p.ctx.LogWarn("Password complexity options MinLength cannot be less than 8. Setting to 8.") + minLength = 8 + } + p.minLength = minLength + return p +} + +func (p *PasswordComplexityOptions) WithMaxLength(maxLength int) *PasswordComplexityOptions { + if maxLength > 40 { + p.ctx.LogWarn("Password complexity options MaxLength cannot be greater than 40. Setting to 40.") + maxLength = 40 + } + + p.maxLength = maxLength + return p +} + +func (p *PasswordComplexityOptions) WithRequireLowercase(requireLowercase bool) *PasswordComplexityOptions { + p.requireLowercase = requireLowercase + return p +} + +func (p *PasswordComplexityOptions) WithRequireUppercase(requireUppercase bool) *PasswordComplexityOptions { + p.requireUppercase = requireUppercase + return p +} + +func (p *PasswordComplexityOptions) WithRequireNumbers(requireNumbers bool) *PasswordComplexityOptions { + p.requireNumbers = requireNumbers + return p +} + +func (p *PasswordComplexityOptions) WithRequireSpecialCharacters(requireSpecialCharacters bool) *PasswordComplexityOptions { + p.requireSpecialCharacters = requireSpecialCharacters + return p +} + +func (p *PasswordComplexityOptions) WithSaltPassword(saltPassword bool) *PasswordComplexityOptions { + p.saltPassword = saltPassword + return p +} + +func (p *PasswordComplexityOptions) MinLength() int { + return p.minLength +} + +func (p *PasswordComplexityOptions) MaxLength() int { + return p.maxLength +} + +func (p *PasswordComplexityOptions) RequireLowercase() bool { + return p.requireLowercase +} + +func (p *PasswordComplexityOptions) RequireUppercase() bool { + return p.requireUppercase +} + +func (p *PasswordComplexityOptions) RequireNumbers() bool { + return p.requireNumbers +} + +func (p *PasswordComplexityOptions) RequireSpecialCharacters() bool { + return p.requireSpecialCharacters +} + +func (p *PasswordComplexityOptions) SaltPassword() bool { + return p.saltPassword +} diff --git a/src/security/password/options_test.go b/src/security/password/options_test.go new file mode 100644 index 00000000..2def077c --- /dev/null +++ b/src/security/password/options_test.go @@ -0,0 +1,40 @@ +package password + +import ( + "testing" + + "github.com/Parallels/pd-api-service/basecontext" + "github.com/stretchr/testify/assert" +) + +func TestSetOptions(t *testing.T) { + ctx := basecontext.NewRootBaseContext() + svc := New(ctx) + svc.SetOptions(&PasswordComplexityOptions{ + ctx: ctx, + minLength: 20, + maxLength: 40, + requireLowercase: true, + requireUppercase: true, + requireNumbers: true, + requireSpecialCharacters: true, + saltPassword: false, + }) + + testSvc := Get() + + assert.Equal(t, 20, svc.GetOptions().MinLength()) + assert.Equal(t, 40, svc.GetOptions().MaxLength()) + assert.Equal(t, true, svc.GetOptions().RequireLowercase()) + assert.Equal(t, true, svc.GetOptions().RequireUppercase()) + assert.Equal(t, true, svc.GetOptions().RequireNumbers()) + assert.Equal(t, true, svc.GetOptions().RequireSpecialCharacters()) + + assert.Equal(t, testSvc.GetOptions().MinLength(), svc.GetOptions().MinLength()) + assert.Equal(t, testSvc.GetOptions().MaxLength(), svc.GetOptions().MaxLength()) + assert.Equal(t, testSvc.GetOptions().RequireLowercase(), svc.GetOptions().RequireLowercase()) + assert.Equal(t, testSvc.GetOptions().RequireUppercase(), svc.GetOptions().RequireUppercase()) + assert.Equal(t, testSvc.GetOptions().RequireNumbers(), svc.GetOptions().RequireNumbers()) + assert.Equal(t, testSvc.GetOptions().RequireSpecialCharacters(), svc.GetOptions().RequireSpecialCharacters()) + assert.Equal(t, testSvc.GetOptions().SaltPassword(), svc.GetOptions().SaltPassword()) +} diff --git a/src/serviceprovider/git/main.go b/src/serviceprovider/git/main.go index 61dc02c7..7a547867 100644 --- a/src/serviceprovider/git/main.go +++ b/src/serviceprovider/git/main.go @@ -198,7 +198,7 @@ func (s *GitService) Clone(ctx basecontext.ApiContext, repoURL string, owner str if owner == "" || owner == "root" { path = filepath.Join("/tmp", localPath) } else { - home, err := system.Get(ctx).GetUserHome(ctx, owner) + home, err := system.Get().GetUserHome(ctx, owner) if err != nil { return "", err } diff --git a/src/serviceprovider/main.go b/src/serviceprovider/main.go index b9b64b87..49b34bab 100644 --- a/src/serviceprovider/main.go +++ b/src/serviceprovider/main.go @@ -103,16 +103,7 @@ func InitCatalogServices(ctx basecontext.ApiContext) { } globalProvider.HardwareId = hid - secretKey := strings.ReplaceAll(key, "-", "") - secretHid := strings.ReplaceAll(hid, "-", "") - if len(secretKey) > 12 { - secretKey = secretKey[:12] - } - if len(secretHid) > 12 { - secretHid = secretHid[:12] - } - - globalProvider.HardwareSecret = base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s%s", secretKey, secretHid))) + globalProvider.HardwareSecret = getHardwareSecret(key, hid) if systemHardwareInfo, err := globalProvider.System.GetHardwareInfo(ctx); err == nil { globalProvider.SystemHardwareInfo = systemHardwareInfo } @@ -204,7 +195,7 @@ func InitServices(ctx basecontext.ApiContext) { } globalProvider.HardwareId = hid - globalProvider.HardwareSecret = base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", key, hid))) + globalProvider.HardwareSecret = getHardwareSecret(key, hid) if systemHardwareInfo, err := globalProvider.System.GetHardwareInfo(ctx); err == nil { globalProvider.SystemHardwareInfo = systemHardwareInfo } @@ -290,3 +281,16 @@ func GetService[T *any](name string) (T, error) { return nil, errors.New("Service not found") } + +func getHardwareSecret(key, hid string) string { + secretKey := strings.ReplaceAll(key, "-", "") + secretHid := strings.ReplaceAll(hid, "-", "") + if len(secretKey) > 12 { + secretKey = secretKey[:12] + } + if len(secretHid) > 12 { + secretHid = secretHid[:12] + } + + return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s%s", secretKey, secretHid))) +} diff --git a/src/serviceprovider/parallelsdesktop/main.go b/src/serviceprovider/parallelsdesktop/main.go index dc54c7bd..b7851466 100644 --- a/src/serviceprovider/parallelsdesktop/main.go +++ b/src/serviceprovider/parallelsdesktop/main.go @@ -256,7 +256,7 @@ func (s *ParallelsService) IsLicensed() bool { func (s *ParallelsService) GetVms(ctx basecontext.ApiContext, filter string) ([]models.ParallelsVM, error) { var systemMachines []models.ParallelsVM - users, err := system.Get(ctx).GetSystemUsers(ctx) + users, err := system.Get().GetSystemUsers(ctx) if err != nil { return nil, err } @@ -841,7 +841,7 @@ func (s *ParallelsService) CreatePackerTemplateVm(ctx basecontext.ApiContext, te ctx.LogInfo("Built packer machine") - users, err := system.Get(ctx).GetSystemUsers(ctx) + users, err := system.Get().GetSystemUsers(ctx) if err != nil { if cleanError := helpers.RemoveFolder(repoPath); cleanError != nil { ctx.LogError("Error removing folder %s: %s", repoPath, cleanError.Error()) @@ -1429,7 +1429,7 @@ func (s *ParallelsService) GetHardwareUsage(ctx basecontext.ApiContext) (*models } } - systemSrv := system.Get(ctx) + systemSrv := system.Get() systemInfo, err := systemSrv.GetHardwareInfo(ctx) if err != nil { return nil, err diff --git a/src/serviceprovider/system/main.go b/src/serviceprovider/system/main.go index 56c109dd..fed84bd8 100644 --- a/src/serviceprovider/system/main.go +++ b/src/serviceprovider/system/main.go @@ -24,10 +24,13 @@ type SystemService struct { dependencies []interfaces.Service } -func Get(ctx basecontext.ApiContext) *SystemService { +func Get() *SystemService { if globalSystemService != nil { return globalSystemService } + + ctx := basecontext.NewBaseContext() + return New(ctx) } diff --git a/src/serviceprovider/vagrant/main.go b/src/serviceprovider/vagrant/main.go index bff7a444..81a9a756 100644 --- a/src/serviceprovider/vagrant/main.go +++ b/src/serviceprovider/vagrant/main.go @@ -297,7 +297,7 @@ func (s *VagrantService) updateVagrantFile(ctx basecontext.ApiContext, filePath } func (s *VagrantService) getVagrantFolderPath(ctx basecontext.ApiContext, request models.CreateVagrantMachineRequest) (string, error) { - system := system.Get(s.ctx) + system := system.Get() rootDir, err := system.GetUserHome(ctx, request.Owner) if err != nil { return "", err diff --git a/src/startup/main.go b/src/startup/main.go index 0e4020c8..d9339dcd 100644 --- a/src/startup/main.go +++ b/src/startup/main.go @@ -9,6 +9,9 @@ import ( "github.com/Parallels/pd-api-service/errors" "github.com/Parallels/pd-api-service/helpers" "github.com/Parallels/pd-api-service/orchestrator" + bruteforceguard "github.com/Parallels/pd-api-service/security/brute_force_guard" + "github.com/Parallels/pd-api-service/security/jwt" + "github.com/Parallels/pd-api-service/security/password" "github.com/Parallels/pd-api-service/serviceprovider" "github.com/Parallels/pd-api-service/serviceprovider/system" ) @@ -17,6 +20,14 @@ const ( ORCHESTRATOR_KEY_NAME = "orchestrator_key" ) +func Init() { + ctx := basecontext.NewRootBaseContext() + + password.New(ctx) + jwt.New(ctx) + bruteforceguard.New(ctx) +} + func Start() { config := config.NewConfig() config.GetLogLevel()