diff --git a/pkg/controller/key/key.go b/pkg/controller/key/key.go index 02c5365..5a8d214 100644 --- a/pkg/controller/key/key.go +++ b/pkg/controller/key/key.go @@ -17,11 +17,12 @@ type Controller struct { Validator *validator.Validate Logger *utility.Logger ExtReq request.ExternalRequest + Key key.KeyService } func (base *Controller) CreateKey(c *gin.Context) { - respData, code, err := key.CreateKey(base.Db.Postgresql.DB(), c) + respData, code, err := base.Key.CreateKey(c) if err != nil { rd := utility.BuildErrorResponse(http.StatusBadRequest, "error", err.Error(), err, nil) c.JSON(http.StatusBadRequest, rd) @@ -37,7 +38,7 @@ func (base *Controller) CreateKey(c *gin.Context) { func (base *Controller) VerifyKey(c *gin.Context) { req := models.VerifyKeyRequestModel{} - respData, code, err := key.VerifyKey(req, base.Db.Postgresql.DB(), c) + respData, code, err := base.Key.VerifyKey(req, c) if err != nil { rd := utility.BuildErrorResponse(http.StatusBadRequest, "error", err.Error(), err, nil) c.JSON(http.StatusBadRequest, rd) diff --git a/pkg/router/auth.go b/pkg/router/auth.go index 0b7f6ee..dcd3a5c 100644 --- a/pkg/router/auth.go +++ b/pkg/router/auth.go @@ -12,13 +12,15 @@ import ( "github.com/hngprojects/hng_boilerplate_golang_web/pkg/controller/key" "github.com/hngprojects/hng_boilerplate_golang_web/pkg/middleware" "github.com/hngprojects/hng_boilerplate_golang_web/pkg/repository/storage" + keyService "github.com/hngprojects/hng_boilerplate_golang_web/services/key" "github.com/hngprojects/hng_boilerplate_golang_web/utility" ) func Auth(r *gin.Engine, ApiVersion string, validator *validator.Validate, db *storage.Database, logger *utility.Logger) *gin.Engine { extReq := request.ExternalRequest{Logger: logger, Test: false} auth := auth.Controller{Db: db, Validator: validator, Logger: logger, ExtReq: extReq} - key := key.Controller{Db: db, Validator: validator, Logger: logger, ExtReq: extReq} + newKey := keyService.NewKeyService(db.Postgresql.DB()) + key := key.Controller{Db: db, Validator: validator, Logger: logger, ExtReq: extReq, Key: newKey} authUrl := r.Group(fmt.Sprintf("%v/auth", ApiVersion)) { diff --git a/services/key/key.go b/services/key/key.go index da5b328..ff8a3d2 100644 --- a/services/key/key.go +++ b/services/key/key.go @@ -14,17 +14,34 @@ import ( "gorm.io/gorm" ) -func CreateKey(db *gorm.DB, c *gin.Context) (gin.H, int, error) { +// KeyService defines the interface for key-related operations +type KeyService interface { + CreateKey(c *gin.Context) (gin.H, int, error) + VerifyKey(req models.VerifyKeyRequestModel, c *gin.Context) (gin.H, int, error) +} + +// keyServiceImpl is the concrete implementation of KeyService +type keyServiceImpl struct { + db *gorm.DB +} + +// NewKeyService initializes a new KeyService instance +func NewKeyService(db *gorm.DB) KeyService { + return &keyServiceImpl{db: db} +} + +// CreateKey generates a new OTP key for a user +func (s *keyServiceImpl) CreateKey(c *gin.Context) (gin.H, int, error) { userID, _ := middleware.GetIdFromToken(c) log.Print(userID) if userID == "" { - return nil, http.StatusBadRequest, errors.New("User is not authenticated") + return nil, http.StatusBadRequest, errors.New("user is not authenticated") } var existingKey models.Key - if err := db.Where("user_id = ?", userID).First(&existingKey).Error; err == nil { - return nil, http.StatusConflict, errors.New("Key for this user already exists") + if err := s.db.Where("user_id = ?", userID).First(&existingKey).Error; err == nil { + return nil, http.StatusConflict, errors.New("key for this user already exists") } else if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, http.StatusInternalServerError, err } @@ -38,9 +55,8 @@ func CreateKey(db *gorm.DB, c *gin.Context) (gin.H, int, error) { } var user models.User - if err := db.Where("id = ?", userID).First(&user).Error; err != nil { - db.Rollback() - return nil, http.StatusNotFound, errors.New("User not found") + if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { + return nil, http.StatusNotFound, errors.New("user not found") } keyModel := models.Key{ @@ -49,13 +65,10 @@ func CreateKey(db *gorm.DB, c *gin.Context) (gin.H, int, error) { Key: secret.Secret(), } - if err := db.Create(&keyModel).Error; err != nil { - db.Rollback() + if err := s.db.Create(&keyModel).Error; err != nil { return nil, http.StatusInternalServerError, err } - db.Commit() - png, err := qrcode.Encode(secret.URL(), qrcode.Medium, 256) if err != nil { return nil, http.StatusInternalServerError, err @@ -67,22 +80,24 @@ func CreateKey(db *gorm.DB, c *gin.Context) (gin.H, int, error) { }, http.StatusCreated, nil } -func VerifyKey(req models.VerifyKeyRequestModel, db *gorm.DB, c *gin.Context) (gin.H, int, error) { +// VerifyKey checks if a given OTP key is valid +func (s *keyServiceImpl) VerifyKey(req models.VerifyKeyRequestModel, c *gin.Context) (gin.H, int, error) { userID, _ := middleware.GetIdFromToken(c) key := req.Key if key == "" || userID == "" { - return nil, http.StatusBadRequest, errors.New("Key and User ID are required") + return nil, http.StatusBadRequest, errors.New("key and user ID are required") } + var keyModel models.Key - if err := db.Where("user_id = ?", userID).First(&keyModel).Error; err != nil { - return nil, http.StatusNotFound, errors.New("Key not found") + if err := s.db.Where("user_id = ?", userID).First(&keyModel).Error; err != nil { + return nil, http.StatusNotFound, errors.New("key not found") } if !totp.Validate(key, keyModel.Key) { - return nil, http.StatusUnauthorized, errors.New("Invalid key") + return nil, http.StatusUnauthorized, errors.New("invalid key") } return gin.H{ - "message": "Key verified successfully", + "message": "key verified successfully", }, http.StatusOK, nil }