Skip to content

Commit

Permalink
add banhammer and user status that defaults to pending (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
broneks authored Dec 1, 2024
1 parent 8058193 commit a507254
Show file tree
Hide file tree
Showing 14 changed files with 159 additions and 16 deletions.
25 changes: 25 additions & 0 deletions api/middleware/ban_hammer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package middleware

import (
"fmt"
"net/http"
"piccolo/api/service/banhammerservice"

"github.com/labstack/echo/v4"
)

func BanHammer(banHammerService *banhammerservice.BanHammerService) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
ip := c.RealIP()

banned, ttl := banHammerService.IsBanned(ctx, ip)
if banned {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("You are banned for %s", ttl))
}

return next(c)
}
}
}
5 changes: 5 additions & 0 deletions api/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@ type User struct {
Id pgtype.Text `json:"id"`
Username pgtype.Text `json:"username"`
Email pgtype.Text `json:"email"`
Status pgtype.Text `json:"-"`
Hash pgtype.Text `json:"-"`
HashedAt pgtype.Timestamptz `json:"-"`
LastLoginAt pgtype.Timestamptz `json:"-"`
CreatedAt pgtype.Timestamptz `json:"createdAt"`
UpdatedAt pgtype.Timestamptz `json:"-"`
}

func (user *User) IsActive() bool {
return user.Status.String == "active"
}
2 changes: 2 additions & 0 deletions api/repo/userrepo/get_by_email.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func (repo *UserRepo) GetByEmail(ctx context.Context, email string) (*model.User
id,
username,
email,
status,
hash,
hashed_at,
last_login_at,
Expand All @@ -26,6 +27,7 @@ func (repo *UserRepo) GetByEmail(ctx context.Context, email string) (*model.User
&user.Id,
&user.Username,
&user.Email,
&user.Status,
&user.Hash,
&user.HashedAt,
&user.LastLoginAt,
Expand Down
2 changes: 2 additions & 0 deletions api/repo/userrepo/get_by_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func (repo *UserRepo) GetById(ctx context.Context, id string) (*model.User, erro
id,
username,
email,
status,
hash,
hashed_at,
last_login_at,
Expand All @@ -26,6 +27,7 @@ func (repo *UserRepo) GetById(ctx context.Context, id string) (*model.User, erro
&user.Id,
&user.Username,
&user.Email,
&user.Status,
&user.Hash,
&user.HashedAt,
&user.LastLoginAt,
Expand Down
10 changes: 10 additions & 0 deletions api/resource/auth/handle_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type LoginReq struct {

func (mod *AuthModule) loginHandler(c echo.Context) error {
ctx := c.Request().Context()
ip := c.RealIP()
req := new(LoginReq)

var err error
Expand Down Expand Up @@ -43,19 +44,28 @@ func (mod *AuthModule) loginHandler(c echo.Context) error {
}

if user == nil {
mod.banHammerService.RecordFailedAttempt(ctx, ip)
return c.JSON(http.StatusBadRequest, types.SuccessRes{
Success: false,
Message: "Invalid email or password.",
})
}

if !mod.authService.VerifyPassword(user.Hash.String, req.Password) {
mod.banHammerService.RecordFailedAttempt(ctx, ip)
return c.JSON(http.StatusBadRequest, types.SuccessRes{
Success: false,
Message: "Invalid email or password.",
})
}

if !user.IsActive() {
return c.JSON(http.StatusBadRequest, types.SuccessRes{
Success: false,
Message: "User is not active.",
})
}

if err = mod.userRepo.UpdateLastLoginAt(ctx, user.Id.String); err != nil {
mod.server.Logger.Error(err.Error())
}
Expand Down
22 changes: 15 additions & 7 deletions api/resource/auth/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@ package auth
import (
"piccolo/api/repo/userrepo"
"piccolo/api/service/authservice"
"piccolo/api/service/banhammerservice"
"piccolo/api/types"
)

type AuthModule struct {
server *types.Server
userRepo *userrepo.UserRepo
authService *authservice.AuthService
server *types.Server
userRepo *userrepo.UserRepo
banHammerService *banhammerservice.BanHammerService
authService *authservice.AuthService
}

func NewModule(server *types.Server, userRepo *userrepo.UserRepo, authService *authservice.AuthService) *AuthModule {
func NewModule(
server *types.Server,
userRepo *userrepo.UserRepo,
banHammerService *banhammerservice.BanHammerService,
authService *authservice.AuthService,
) *AuthModule {
return &AuthModule{
server: server,
userRepo: userRepo,
authService: authService,
server: server,
userRepo: userRepo,
banHammerService: banHammerService,
authService: authService,
}
}
8 changes: 2 additions & 6 deletions api/resource/auth/routes.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
package auth

import (
"os"
"piccolo/api/middleware"

"github.com/labstack/echo/v4"
)

func (mod *AuthModule) Routes(g *echo.Group) {
auth := g.Group("/auth")

if os.Getenv("ENV") != "local" {
return
}

auth.POST("/register", mod.registerHandler)
auth.POST("/login", mod.loginHandler)
auth.POST("/login", mod.loginHandler, middleware.BanHammer(mod.banHammerService))

auth.POST("/refresh", mod.refreshHandler)
auth.POST("/logout", mod.logoutHandler)
Expand Down
4 changes: 3 additions & 1 deletion api/resource/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"piccolo/api/resource/auth"
"piccolo/api/resource/photo"
"piccolo/api/service/authservice"
"piccolo/api/service/banhammerservice"
"piccolo/api/service/photoservice"
"piccolo/api/types"

Expand All @@ -24,10 +25,11 @@ func Routes(g *echo.Group, server *types.Server) {
photoRepo := photorepo.New(server.DB)
albumRepo := albumrepo.New(server.DB)

banHammerService := banhammerservice.New()
authService := authservice.New(server, userRepo)
photoService := photoservice.New(server, photoRepo)

authModule := auth.NewModule(server, userRepo, authService)
authModule := auth.NewModule(server, userRepo, banHammerService, authService)
authModule.Routes(v1)

photoModule := photo.NewModule(server, photoRepo, photoService)
Expand Down
4 changes: 2 additions & 2 deletions api/security/redis_rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"log"
"log/slog"
"os"

redisrate "github.com/go-redis/redis_rate/v10"
Expand All @@ -20,7 +20,7 @@ const rateRequest = "rate_request_%s"
func NewRedisRateLimiter() *RedisRateLimiter {
opts, err := redis.ParseURL(os.Getenv("REDIS_URL"))
if err != nil {
log.Fatalf("cannot create redis connection: %v", err)
slog.Error("cannot create redis connection", "error", err)
}

rdb := redis.NewClient(opts)
Expand Down
20 changes: 20 additions & 0 deletions api/service/banhammerservice/is_banned.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package banhammerservice

import (
"context"
"log/slog"
"time"

"github.com/redis/go-redis/v9"
)

func (svc *BanHammerService) IsBanned(ctx context.Context, ip string) (bool, time.Duration) {
key := banKeyPrefix + ip

ttl, err := svc.rdb.TTL(ctx, key).Result()
if err != nil && err != redis.Nil {
slog.Error("Error checking ban status:", "error", err)
}

return ttl > 0, ttl
}
35 changes: 35 additions & 0 deletions api/service/banhammerservice/record_failed_attempt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package banhammerservice

import (
"context"
"log/slog"
)

func (svc *BanHammerService) RecordFailedAttempt(ctx context.Context, ip string) error {
attemptsKey := attemptKeyPrefix + ip
banKey := banKeyPrefix + ip

// Increment the failed attempts count atomically
attempts, err := svc.rdb.Incr(ctx, attemptsKey).Result()
if err != nil {
return err
}

// Set expiration for the tracking key if it’s the first attempt
if attempts == 1 {
svc.rdb.Expire(ctx, attemptsKey, trackingTime)
}

// Check if the IP should be banned
if attempts >= maxAttempts {
err := svc.rdb.Set(ctx, banKey, "1", banDuration).Err() // Ban the IP
if err != nil {
return err
}
// Cleanup the attempts key
svc.rdb.Del(ctx, attemptsKey)
slog.Info("IP %s has been banned for %s\n", ip, banDuration)
}

return nil
}
32 changes: 32 additions & 0 deletions api/service/banhammerservice/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package banhammerservice

import (
"log/slog"
"os"
"time"

"github.com/redis/go-redis/v9"
)

type BanHammerService struct {
rdb *redis.Client
}

const (
maxAttempts = 10
banDuration = 1 * time.Hour
trackingTime = 10 * time.Minute
banKeyPrefix = "ban-hammer:ban:"
attemptKeyPrefix = "ban-hammer:attempt:"
)

func New() *BanHammerService {
opts, err := redis.ParseURL(os.Getenv("REDIS_URL"))
if err != nil {
slog.Error("cannot create redis connection", "error", err)
}

rdb := redis.NewClient(opts)

return &BanHammerService{rdb}
}
2 changes: 2 additions & 0 deletions db/migrations/000014_update_users_add_status.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
alter table users
drop column if exists status;
4 changes: 4 additions & 0 deletions db/migrations/000014_update_users_add_status.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
create type user_status_enum as enum ('pending', 'active', 'suspended');

alter table users
add column status user_status_enum default 'pending';

0 comments on commit a507254

Please sign in to comment.