diff --git a/cmd/root.go b/cmd/root.go index 43528a8..f23b624 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,6 +14,7 @@ import ( "github.com/marcopollivier/techagenda/lib/ssr" "github.com/marcopollivier/techagenda/pkg/event" "github.com/marcopollivier/techagenda/pkg/lending" + "github.com/marcopollivier/techagenda/pkg/oauth" "github.com/marcopollivier/techagenda/pkg/static" "github.com/marcopollivier/techagenda/pkg/user" ) @@ -27,12 +28,13 @@ var rootCmd = &cobra.Command{ fx.New( fx.Provide(database.NewDB), fx.Provide(server.NewHTTPServer), + oauth.Module(), ssr.Module(), - static.Module(), user.Module(), event.Module(), - // attendee.Module(), + lending.Module(), + static.Module(), ).Run() }, } diff --git a/migrations/20240113234311_create_user_table.go b/migrations/20240113234311_create_user_table.go index 15929ad..2952a32 100644 --- a/migrations/20240113234311_create_user_table.go +++ b/migrations/20240113234311_create_user_table.go @@ -27,10 +27,10 @@ func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { created_at TIMESTAMP DEFAULT now(), updated_at TIMESTAMP DEFAULT now(), deleted_at TIMESTAMP, - + UNIQUE(email) ); - + CREATE INDEX idx_users_role on users (role); `); err != nil { return err @@ -39,16 +39,16 @@ func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { // OAuth table if _, err := tx.ExecContext(ctx, ` CREATE TYPE provider AS ENUM ('github'); - + CREATE TABLE IF NOT EXISTS oauths ( id BIGSERIAL PRIMARY KEY, - user_id BIGINT NOT NULL, + user_id BIGINT NOT NULL REFERENCES users(id), provider PROVIDER NOT NULL, identifier TEXT NOT NULL, created_at TIMESTAMP DEFAULT now(), updated_at TIMESTAMP DEFAULT now(), deleted_at TIMESTAMP, - + CONSTRAINT fk_oauths_user_id FOREIGN KEY (user_id) REFERENCES users(id), UNIQUE(user_id, provider, identifier) ); diff --git a/pkg/event/model_enum.go b/pkg/event/model_enum.go index d1e2312..67c4c36 100644 --- a/pkg/event/model_enum.go +++ b/pkg/event/model_enum.go @@ -1,8 +1,8 @@ // Code generated by go-enum DO NOT EDIT. -// Version: 0.6.0 -// Revision: 919e61c0174b91303753ee3898569a01abb32c97 -// Build Date: 2023-12-18T15:54:43Z -// Built By: goreleaser +// Version: +// Revision: +// Build Date: +// Built By: package event diff --git a/pkg/lending/handler.go b/pkg/lending/handler.go index ec0bd14..590eed1 100644 --- a/pkg/lending/handler.go +++ b/pkg/lending/handler.go @@ -9,7 +9,7 @@ import ( "github.com/marcopollivier/techagenda/lib/ssr" "github.com/marcopollivier/techagenda/pkg/event" - "github.com/marcopollivier/techagenda/pkg/user" + "github.com/marcopollivier/techagenda/pkg/oauth" ) type QueryParams struct { @@ -45,7 +45,7 @@ func NewLendingHandler(server *echo.Echo, eventService event.Service, engine *ss }, Props: &ssr.Props{ Events: events, - User: user.GetUserFromCtx(ctx), + User: oauth.GetUserFromCtx(ctx), }, }) return c.HTML(http.StatusOK, string(page)) diff --git a/pkg/oauth/fx.go b/pkg/oauth/fx.go new file mode 100644 index 0000000..f64473b --- /dev/null +++ b/pkg/oauth/fx.go @@ -0,0 +1,11 @@ +package oauth + +import "go.uber.org/fx" + +func Module() fx.Option { + return fx.Module("oauth", + fx.Provide(NewOAuthService), + fx.Provide(NewOAuthHandler), + fx.Invoke(SetOAuthAPIRoutes), + ) +} diff --git a/pkg/user/handler_oauth.go b/pkg/oauth/handler.go similarity index 86% rename from pkg/user/handler_oauth.go rename to pkg/oauth/handler.go index 9d89969..9c3dab6 100644 --- a/pkg/user/handler_oauth.go +++ b/pkg/oauth/handler.go @@ -1,4 +1,4 @@ -package user +package oauth import ( "fmt" @@ -8,22 +8,31 @@ import ( "github.com/labstack/echo/v4" "github.com/marcopollivier/techagenda/lib/server" "github.com/marcopollivier/techagenda/lib/session" + "github.com/marcopollivier/techagenda/pkg/user" "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/samber/lo" ) -func (h *UserHandler) AuthLogin(c echo.Context) (err error) { +type OAuthHandler struct { + service Service +} + +func NewOAuthHandler(service Service) *OAuthHandler { + return &OAuthHandler{service: service} +} + +func (h *OAuthHandler) AuthLogin(c echo.Context) (err error) { var ( ctx = c.Request().Context() res = c.Response() req = c.Request() authUser goth.User - userData User + userData user.User token string ) - if _, ok := c.Request().Context().Value(MiddlewareUserKey).(User); ok { + if _, ok := c.Request().Context().Value(MiddlewareUserKey).(user.User); ok { res.Header().Set("Location", getReferer(req)) res.WriteHeader(http.StatusTemporaryRedirect) return @@ -55,7 +64,7 @@ func (h *UserHandler) AuthLogin(c echo.Context) (err error) { return } -func (h *UserHandler) AuthLogout(c echo.Context) (err error) { +func (h *OAuthHandler) AuthLogout(c echo.Context) (err error) { var ( res = c.Response() req = c.Request() @@ -73,18 +82,18 @@ func (h *UserHandler) AuthLogout(c echo.Context) (err error) { return } -func (h *UserHandler) AuthCallback(c echo.Context) (err error) { +func (h *OAuthHandler) AuthCallback(c echo.Context) (err error) { var ( ctx = c.Request().Context() res = c.Response() req = c.Request() authUser goth.User - userData User + userData user.User token string provider string ) - if _, ok := c.Request().Context().Value(MiddlewareUserKey).(User); ok { + if _, ok := c.Request().Context().Value(MiddlewareUserKey).(user.User); ok { res.Header().Set("Location", getReferer(req)) res.WriteHeader(http.StatusTemporaryRedirect) return diff --git a/pkg/user/middleware.go b/pkg/oauth/middleware.go similarity index 80% rename from pkg/user/middleware.go rename to pkg/oauth/middleware.go index 70a4bd6..20da0de 100644 --- a/pkg/user/middleware.go +++ b/pkg/oauth/middleware.go @@ -1,4 +1,4 @@ -package user +package oauth import ( "context" @@ -6,6 +6,7 @@ import ( "github.com/labstack/echo/v4" "github.com/marcopollivier/techagenda/lib/session" + "github.com/marcopollivier/techagenda/pkg/user" ) type MiddlewareCtxKey string @@ -19,7 +20,7 @@ func AuthMiddleware(service Service) echo.MiddlewareFunc { return func(c echo.Context) (err error) { var ( ctx = c.Request().Context() - user User + user user.User req = c.Request() res = c.Response() userSession session.UserSession @@ -38,9 +39,9 @@ func AuthMiddleware(service Service) echo.MiddlewareFunc { } } -func GetUserFromCtx(ctx context.Context) *User { - var userPtr *User - if userData, ok := ctx.Value(MiddlewareUserKey).(User); ok { +func GetUserFromCtx(ctx context.Context) *user.User { + var userPtr *user.User + if userData, ok := ctx.Value(MiddlewareUserKey).(user.User); ok { userPtr = &userData } return userPtr diff --git a/pkg/oauth/oauth.go b/pkg/oauth/oauth.go new file mode 100644 index 0000000..3386e8b --- /dev/null +++ b/pkg/oauth/oauth.go @@ -0,0 +1,19 @@ +package oauth + +import ( + "gorm.io/gorm" +) + +//go:generate go-enum --marshal --sql -f oauth.go + +type OAuth struct { + gorm.Model + UserID uint + Provider Provider + Identifier string +} + +func (OAuth) TableName() string { return "oauths" } + +// ENUM(github) +type Provider int diff --git a/pkg/oauth/oauth_enum.go b/pkg/oauth/oauth_enum.go new file mode 100644 index 0000000..29c671f --- /dev/null +++ b/pkg/oauth/oauth_enum.go @@ -0,0 +1,142 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: +// Revision: +// Build Date: +// Built By: + +package oauth + +import ( + "database/sql/driver" + "errors" + "fmt" +) + +const ( + // ProviderGithub is a Provider of type Github. + ProviderGithub Provider = iota +) + +var ErrInvalidProvider = errors.New("not a valid Provider") + +const _ProviderName = "github" + +var _ProviderMap = map[Provider]string{ + ProviderGithub: _ProviderName[0:6], +} + +// String implements the Stringer interface. +func (x Provider) String() string { + if str, ok := _ProviderMap[x]; ok { + return str + } + return fmt.Sprintf("Provider(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x Provider) IsValid() bool { + _, ok := _ProviderMap[x] + return ok +} + +var _ProviderValue = map[string]Provider{ + _ProviderName[0:6]: ProviderGithub, +} + +// ParseProvider attempts to convert a string to a Provider. +func ParseProvider(name string) (Provider, error) { + if x, ok := _ProviderValue[name]; ok { + return x, nil + } + return Provider(0), fmt.Errorf("%s is %w", name, ErrInvalidProvider) +} + +// MarshalText implements the text marshaller method. +func (x Provider) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *Provider) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseProvider(name) + if err != nil { + return err + } + *x = tmp + return nil +} + +var errProviderNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *Provider) Scan(value interface{}) (err error) { + if value == nil { + *x = Provider(0) + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x = Provider(v) + case string: + *x, err = ParseProvider(v) + case []byte: + *x, err = ParseProvider(string(v)) + case Provider: + *x = v + case int: + *x = Provider(v) + case *Provider: + if v == nil { + return errProviderNilPtr + } + *x = *v + case uint: + *x = Provider(v) + case uint64: + *x = Provider(v) + case *int: + if v == nil { + return errProviderNilPtr + } + *x = Provider(*v) + case *int64: + if v == nil { + return errProviderNilPtr + } + *x = Provider(*v) + case float64: // json marshals everything as a float64 if it's a number + *x = Provider(v) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil { + return errProviderNilPtr + } + *x = Provider(*v) + case *uint: + if v == nil { + return errProviderNilPtr + } + *x = Provider(*v) + case *uint64: + if v == nil { + return errProviderNilPtr + } + *x = Provider(*v) + case *string: + if v == nil { + return errProviderNilPtr + } + *x, err = ParseProvider(*v) + } + + return +} + +// Value implements the driver Valuer interface. +func (x Provider) Value() (driver.Value, error) { + return x.String(), nil +} diff --git a/pkg/user/oauth_providers.go b/pkg/oauth/providers.go similarity index 97% rename from pkg/user/oauth_providers.go rename to pkg/oauth/providers.go index 23339b6..879a99c 100644 --- a/pkg/user/oauth_providers.go +++ b/pkg/oauth/providers.go @@ -1,4 +1,4 @@ -package user +package oauth import ( "fmt" diff --git a/pkg/oauth/router.go b/pkg/oauth/router.go new file mode 100644 index 0000000..9d59914 --- /dev/null +++ b/pkg/oauth/router.go @@ -0,0 +1,13 @@ +package oauth + +import "github.com/labstack/echo/v4" + +func SetOAuthAPIRoutes(server *echo.Echo, handler *OAuthHandler) { + registerProviders() + server.Use(AuthMiddleware(handler.service)) + auth := server.Group("/auth") + + auth.GET("", handler.AuthLogin) + auth.GET("/logout", handler.AuthLogout) + auth.GET("/callback", handler.AuthCallback) +} diff --git a/pkg/oauth/service.go b/pkg/oauth/service.go new file mode 100644 index 0000000..8077961 --- /dev/null +++ b/pkg/oauth/service.go @@ -0,0 +1,95 @@ +package oauth + +import ( + "context" + "fmt" + "log/slog" + + "github.com/marcopollivier/techagenda/pkg/user" + "github.com/markbates/goth" + "gorm.io/gorm" +) + +type Service interface { + Auth(ctx context.Context, oauthUser goth.User) (user user.User, err error) +} + +type OAuthService struct { + db *gorm.DB + userService user.Service +} + +func NewOAuthService(db *gorm.DB, userService user.Service) Service { + return &OAuthService{ + db: db, + userService: userService, + } +} + +func (s *OAuthService) Auth(ctx context.Context, oauthUser goth.User) (authUser user.User, err error) { + var ( + oauth OAuth + provider Provider + ) + + if err = s.db.WithContext(ctx). + Where("provider = ?", oauthUser.Provider). + Where("identifier = ?", oauthUser.UserID). + First(&oauth).Error; err != nil && err != gorm.ErrRecordNotFound { + slog.ErrorContext(ctx, "Unexpected error searching for oauth link", "provider", oauthUser.Provider, "error", err.Error()) + return authUser, err + } + + // If the provider and id does not match to any one already on the database, we need to link with an user if the email already exists, if not we need to create a new user and link it. + if err == gorm.ErrRecordNotFound { + slog.WarnContext(ctx, fmt.Sprintf("We didn't found a oauth link for email %s and provider %s", oauthUser.Email, oauthUser.Provider)) + if authUser, err = s.userService.GetByEmail(ctx, oauthUser.Email); err != nil { + if err != gorm.ErrRecordNotFound { + slog.ErrorContext(ctx, "Unexpected error searching for user", "error", err.Error()) + return authUser, err + } + + slog.WarnContext(ctx, "No user found to this oauth link, creating a new one") + authUser = user.User{ + Email: oauthUser.Email, + Name: oauthUser.Name, + Avatar: oauthUser.AvatarURL, + Bio: oauthUser.Description, + } + if authUser, err = s.userService.Create(ctx, authUser); err != nil { + return authUser, err + } + } + + if provider, err = ParseProvider(oauthUser.Provider); err != nil { + slog.ErrorContext(ctx, fmt.Sprintf("Unexpected provider %s", oauthUser.Provider), "error", err.Error()) + return authUser, err + } + + slog.InfoContext(ctx, fmt.Sprintf("Linking user %d to oauth provider %s of identifier %s", authUser.ID, oauthUser.Provider, oauthUser.UserID)) + oauth = OAuth{ + UserID: authUser.ID, + Provider: provider, + Identifier: oauthUser.UserID, + } + if err = s.db.WithContext(ctx).Create(&oauth).Error; err != nil { + slog.ErrorContext(ctx, "Fail to create link of oauth user", "user", authUser.ID, "error", err.Error()) + return authUser, err + } + return authUser, err + } + + // If oauth is linked with a user just return the user + if authUser, err = s.userService.Get(ctx, oauth.UserID); err != nil { + return + } + + go func() { + if authUser.Avatar != oauthUser.AvatarURL { + if _, err := s.userService.UpdateAvatar(context.Background(), authUser.ID, oauthUser.AvatarURL); err != nil { + slog.Error("Unable to update users avatar", "user_id", authUser.ID, "error", err.Error()) + } + } + }() + return +} diff --git a/pkg/user/models_enum.go b/pkg/user/models_enum.go deleted file mode 100644 index 9f61c55..0000000 --- a/pkg/user/models_enum.go +++ /dev/null @@ -1,279 +0,0 @@ -// Code generated by go-enum DO NOT EDIT. -// Version: 0.6.0 -// Revision: 919e61c0174b91303753ee3898569a01abb32c97 -// Build Date: 2023-12-18T15:54:43Z -// Built By: goreleaser - -package user - -import ( - "database/sql/driver" - "errors" - "fmt" -) - -const ( - // ProviderGithub is a Provider of type Github. - ProviderGithub Provider = iota -) - -var ErrInvalidProvider = errors.New("not a valid Provider") - -const _ProviderName = "github" - -var _ProviderMap = map[Provider]string{ - ProviderGithub: _ProviderName[0:6], -} - -// String implements the Stringer interface. -func (x Provider) String() string { - if str, ok := _ProviderMap[x]; ok { - return str - } - return fmt.Sprintf("Provider(%d)", x) -} - -// IsValid provides a quick way to determine if the typed value is -// part of the allowed enumerated values -func (x Provider) IsValid() bool { - _, ok := _ProviderMap[x] - return ok -} - -var _ProviderValue = map[string]Provider{ - _ProviderName[0:6]: ProviderGithub, -} - -// ParseProvider attempts to convert a string to a Provider. -func ParseProvider(name string) (Provider, error) { - if x, ok := _ProviderValue[name]; ok { - return x, nil - } - return Provider(0), fmt.Errorf("%s is %w", name, ErrInvalidProvider) -} - -// MarshalText implements the text marshaller method. -func (x Provider) MarshalText() ([]byte, error) { - return []byte(x.String()), nil -} - -// UnmarshalText implements the text unmarshaller method. -func (x *Provider) UnmarshalText(text []byte) error { - name := string(text) - tmp, err := ParseProvider(name) - if err != nil { - return err - } - *x = tmp - return nil -} - -var errProviderNilPtr = errors.New("value pointer is nil") // one per type for package clashes - -// Scan implements the Scanner interface. -func (x *Provider) Scan(value interface{}) (err error) { - if value == nil { - *x = Provider(0) - return - } - - // A wider range of scannable types. - // driver.Value values at the top of the list for expediency - switch v := value.(type) { - case int64: - *x = Provider(v) - case string: - *x, err = ParseProvider(v) - case []byte: - *x, err = ParseProvider(string(v)) - case Provider: - *x = v - case int: - *x = Provider(v) - case *Provider: - if v == nil { - return errProviderNilPtr - } - *x = *v - case uint: - *x = Provider(v) - case uint64: - *x = Provider(v) - case *int: - if v == nil { - return errProviderNilPtr - } - *x = Provider(*v) - case *int64: - if v == nil { - return errProviderNilPtr - } - *x = Provider(*v) - case float64: // json marshals everything as a float64 if it's a number - *x = Provider(v) - case *float64: // json marshals everything as a float64 if it's a number - if v == nil { - return errProviderNilPtr - } - *x = Provider(*v) - case *uint: - if v == nil { - return errProviderNilPtr - } - *x = Provider(*v) - case *uint64: - if v == nil { - return errProviderNilPtr - } - *x = Provider(*v) - case *string: - if v == nil { - return errProviderNilPtr - } - *x, err = ParseProvider(*v) - } - - return -} - -// Value implements the driver Valuer interface. -func (x Provider) Value() (driver.Value, error) { - return x.String(), nil -} - -const ( - // RoleUser is a Role of type User. - RoleUser Role = iota - // RoleMod is a Role of type Mod. - RoleMod - // RoleAdmin is a Role of type Admin. - RoleAdmin -) - -var ErrInvalidRole = errors.New("not a valid Role") - -const _RoleName = "usermodadmin" - -var _RoleMap = map[Role]string{ - RoleUser: _RoleName[0:4], - RoleMod: _RoleName[4:7], - RoleAdmin: _RoleName[7:12], -} - -// String implements the Stringer interface. -func (x Role) String() string { - if str, ok := _RoleMap[x]; ok { - return str - } - return fmt.Sprintf("Role(%d)", x) -} - -// IsValid provides a quick way to determine if the typed value is -// part of the allowed enumerated values -func (x Role) IsValid() bool { - _, ok := _RoleMap[x] - return ok -} - -var _RoleValue = map[string]Role{ - _RoleName[0:4]: RoleUser, - _RoleName[4:7]: RoleMod, - _RoleName[7:12]: RoleAdmin, -} - -// ParseRole attempts to convert a string to a Role. -func ParseRole(name string) (Role, error) { - if x, ok := _RoleValue[name]; ok { - return x, nil - } - return Role(0), fmt.Errorf("%s is %w", name, ErrInvalidRole) -} - -// MarshalText implements the text marshaller method. -func (x Role) MarshalText() ([]byte, error) { - return []byte(x.String()), nil -} - -// UnmarshalText implements the text unmarshaller method. -func (x *Role) UnmarshalText(text []byte) error { - name := string(text) - tmp, err := ParseRole(name) - if err != nil { - return err - } - *x = tmp - return nil -} - -var errRoleNilPtr = errors.New("value pointer is nil") // one per type for package clashes - -// Scan implements the Scanner interface. -func (x *Role) Scan(value interface{}) (err error) { - if value == nil { - *x = Role(0) - return - } - - // A wider range of scannable types. - // driver.Value values at the top of the list for expediency - switch v := value.(type) { - case int64: - *x = Role(v) - case string: - *x, err = ParseRole(v) - case []byte: - *x, err = ParseRole(string(v)) - case Role: - *x = v - case int: - *x = Role(v) - case *Role: - if v == nil { - return errRoleNilPtr - } - *x = *v - case uint: - *x = Role(v) - case uint64: - *x = Role(v) - case *int: - if v == nil { - return errRoleNilPtr - } - *x = Role(*v) - case *int64: - if v == nil { - return errRoleNilPtr - } - *x = Role(*v) - case float64: // json marshals everything as a float64 if it's a number - *x = Role(v) - case *float64: // json marshals everything as a float64 if it's a number - if v == nil { - return errRoleNilPtr - } - *x = Role(*v) - case *uint: - if v == nil { - return errRoleNilPtr - } - *x = Role(*v) - case *uint64: - if v == nil { - return errRoleNilPtr - } - *x = Role(*v) - case *string: - if v == nil { - return errRoleNilPtr - } - *x, err = ParseRole(*v) - } - - return -} - -// Value implements the driver Valuer interface. -func (x Role) Value() (driver.Value, error) { - return x.String(), nil -} diff --git a/pkg/user/router.go b/pkg/user/router.go index 15744ac..624087f 100644 --- a/pkg/user/router.go +++ b/pkg/user/router.go @@ -3,14 +3,6 @@ package user import "github.com/labstack/echo/v4" func SetUserHandlerRoutes(server *echo.Echo, handler *UserHandler) { - registerProviders() - server.Use(AuthMiddleware(handler.service)) - auth := server.Group("/auth") - - auth.GET("", handler.AuthLogin) - auth.GET("/logout", handler.AuthLogout) - auth.GET("/callback", handler.AuthCallback) - grp := server.Group("/users") grp.GET("", handler.ListAll) } diff --git a/pkg/user/service.go b/pkg/user/service.go index 051acc8..7bb94cb 100644 --- a/pkg/user/service.go +++ b/pkg/user/service.go @@ -5,14 +5,15 @@ import ( "fmt" "log/slog" - "github.com/markbates/goth" "gorm.io/gorm" ) type Service interface { - Auth(ctx context.Context, oauthUser goth.User) (user User, err error) + Create(ctx context.Context, u User) (user User, err error) Get(ctx context.Context, userID uint) (user User, err error) + GetByEmail(ctx context.Context, email string) (user User, err error) ListAll(ctx context.Context, role Role) (users []User, err error) + UpdateAvatar(ctx context.Context, userID uint, newAvatarHref string) (user User, err error) } type UserService struct { @@ -25,71 +26,12 @@ func NewUserService(db *gorm.DB) Service { } } -func (s *UserService) Auth(ctx context.Context, oauthUser goth.User) (user User, err error) { - var oauth OAuth - if err = s.db.WithContext(ctx). - Where("provider = ?", oauthUser.Provider). - Where("identifier = ?", oauthUser.UserID). - First(&oauth).Error; err != nil && err != gorm.ErrRecordNotFound { - slog.ErrorContext(ctx, "Unexpected error searching for oauth link", "provider", oauthUser.Provider, "error", err.Error()) - return user, err +func (s *UserService) Create(ctx context.Context, u User) (user User, err error) { + if err = s.db.WithContext(ctx).Create(&u).Error; err != nil { + slog.ErrorContext(ctx, "Fail to create new user", "error", err.Error(), "user", u) + return u, err } - - // If the provider and id does not match to any one already on the database, we need to link with an user if the email already exists, if not we need to create a new user and link it. - if err == gorm.ErrRecordNotFound { - var provider Provider - slog.WarnContext(ctx, fmt.Sprintf("We didn't found a oauth link for email %s and provider %s", oauthUser.Email, oauthUser.Provider)) - if err = s.db.WithContext(ctx).Where("email = ?", oauthUser.Email).First(&user).Error; err != nil { - if err != gorm.ErrRecordNotFound { - slog.ErrorContext(ctx, "Unexpected error searching for user", "error", err.Error()) - return user, err - } - - slog.WarnContext(ctx, "No user found to this oauth link, creating a new one") - user = User{ - Email: oauthUser.Email, - Name: oauthUser.Name, - Avatar: oauthUser.AvatarURL, - Bio: oauthUser.Description, - } - if err = s.db.WithContext(ctx).Create(&user).Error; err != nil { - slog.ErrorContext(ctx, "Fail to create new user", "error", err.Error()) - return user, err - } - } - - if provider, err = ParseProvider(oauthUser.Provider); err != nil { - slog.ErrorContext(ctx, fmt.Sprintf("Unexpected provider %s", oauthUser.Provider), "error", err.Error()) - return user, err - } - - slog.InfoContext(ctx, fmt.Sprintf("Linking user %d to oauth provider %s of identifier %s", user.ID, oauthUser.Provider, oauthUser.UserID)) - oauth = OAuth{ - UserID: user.ID, - Provider: provider, - Identifier: oauthUser.UserID, - } - if err = s.db.WithContext(ctx).Create(&oauth).Error; err != nil { - slog.ErrorContext(ctx, "Fail to create link of oauth user", "user", user.ID, "error", err.Error()) - return user, err - } - return user, err - } - - // If oauth is linked with a user just return the user - if user, err = s.Get(ctx, oauth.UserID); err != nil { - return - } - - go func() { - if user.Avatar != oauthUser.AvatarURL { - user.Avatar = oauthUser.AvatarURL - if errI := s.db.Where("id = ?", user.ID).Updates(&user).Error; errI != nil { - slog.ErrorContext(ctx, "Unable to update users avatar!", "user", user.ID, "error", errI.Error()) - } - } - }() - return + return u, err } func (s *UserService) Get(ctx context.Context, userID uint) (user User, err error) { @@ -99,9 +41,27 @@ func (s *UserService) Get(ctx context.Context, userID uint) (user User, err erro return user, err } +func (s *UserService) GetByEmail(ctx context.Context, email string) (user User, err error) { + if err = s.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil { + slog.ErrorContext(ctx, "Unable to find user!", "email", email, "error", err.Error()) + } + return user, err +} + func (s *UserService) ListAll(ctx context.Context, role Role) (users []User, err error) { if err = s.db.WithContext(ctx).Model(new(User)).Where("role = ?", role.String()).Scan(&users).Error; err != nil { slog.ErrorContext(ctx, fmt.Sprintf("Fail to list users of role %s", role), "error", err.Error()) } return users, err } + +func (s *UserService) UpdateAvatar(ctx context.Context, userID uint, newAvatarHref string) (user User, err error) { + if err = s.db.WithContext(ctx).Where("id = ?", userID).First(&user).Error; err != nil { + slog.ErrorContext(ctx, "Unable to find user!", "user", userID, "error", err.Error()) + } + user.Avatar = newAvatarHref + if errI := s.db.WithContext(ctx).Where("id = ?", user.ID).Updates(&user).Error; errI != nil { + slog.ErrorContext(ctx, "Unable to update users avatar!", "user", user.ID, "error", errI.Error()) + } + return user, err +} diff --git a/pkg/user/models.go b/pkg/user/user.go similarity index 55% rename from pkg/user/models.go rename to pkg/user/user.go index 737edf0..49af94f 100644 --- a/pkg/user/models.go +++ b/pkg/user/user.go @@ -2,7 +2,7 @@ package user import "gorm.io/gorm" -//go:generate go-enum --marshal --sql -f models.go +//go:generate go-enum --marshal --sql -f user.go type User struct { gorm.Model @@ -16,19 +16,5 @@ type User struct { func (u *User) IsAdmin() bool { return u.Role == RoleAdmin } func (u *User) IsMod() bool { return u.Role == RoleMod } -type OAuth struct { - gorm.Model - UserID uint - Provider Provider - Identifier string - - User User -} - -func (OAuth) TableName() string { return "oauths" } - // ENUM(user, mod, admin) type Role int - -// ENUM(github) -type Provider int diff --git a/pkg/user/user_enum.go b/pkg/user/user_enum.go new file mode 100644 index 0000000..7e3c155 --- /dev/null +++ b/pkg/user/user_enum.go @@ -0,0 +1,150 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: +// Revision: +// Build Date: +// Built By: + +package user + +import ( + "database/sql/driver" + "errors" + "fmt" +) + +const ( + // RoleUser is a Role of type User. + RoleUser Role = iota + // RoleMod is a Role of type Mod. + RoleMod + // RoleAdmin is a Role of type Admin. + RoleAdmin +) + +var ErrInvalidRole = errors.New("not a valid Role") + +const _RoleName = "usermodadmin" + +var _RoleMap = map[Role]string{ + RoleUser: _RoleName[0:4], + RoleMod: _RoleName[4:7], + RoleAdmin: _RoleName[7:12], +} + +// String implements the Stringer interface. +func (x Role) String() string { + if str, ok := _RoleMap[x]; ok { + return str + } + return fmt.Sprintf("Role(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x Role) IsValid() bool { + _, ok := _RoleMap[x] + return ok +} + +var _RoleValue = map[string]Role{ + _RoleName[0:4]: RoleUser, + _RoleName[4:7]: RoleMod, + _RoleName[7:12]: RoleAdmin, +} + +// ParseRole attempts to convert a string to a Role. +func ParseRole(name string) (Role, error) { + if x, ok := _RoleValue[name]; ok { + return x, nil + } + return Role(0), fmt.Errorf("%s is %w", name, ErrInvalidRole) +} + +// MarshalText implements the text marshaller method. +func (x Role) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *Role) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := ParseRole(name) + if err != nil { + return err + } + *x = tmp + return nil +} + +var errRoleNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *Role) Scan(value interface{}) (err error) { + if value == nil { + *x = Role(0) + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x = Role(v) + case string: + *x, err = ParseRole(v) + case []byte: + *x, err = ParseRole(string(v)) + case Role: + *x = v + case int: + *x = Role(v) + case *Role: + if v == nil { + return errRoleNilPtr + } + *x = *v + case uint: + *x = Role(v) + case uint64: + *x = Role(v) + case *int: + if v == nil { + return errRoleNilPtr + } + *x = Role(*v) + case *int64: + if v == nil { + return errRoleNilPtr + } + *x = Role(*v) + case float64: // json marshals everything as a float64 if it's a number + *x = Role(v) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil { + return errRoleNilPtr + } + *x = Role(*v) + case *uint: + if v == nil { + return errRoleNilPtr + } + *x = Role(*v) + case *uint64: + if v == nil { + return errRoleNilPtr + } + *x = Role(*v) + case *string: + if v == nil { + return errRoleNilPtr + } + *x, err = ParseRole(*v) + } + + return +} + +// Value implements the driver Valuer interface. +func (x Role) Value() (driver.Value, error) { + return x.String(), nil +}