diff --git a/go.mod b/go.mod index 4d03812..4d5250d 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.20 github.com/spf13/cobra v1.8.0 - github.com/swaggo/swag v1.16.1 + github.com/swaggo/swag v1.16.3 github.com/testcontainers/testcontainers-go v0.30.0 github.com/testcontainers/testcontainers-go/modules/mysql v0.27.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.30.0 @@ -37,6 +37,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/go-chi/chi/v5 v5.0.12 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -66,6 +67,8 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/swaggo/files v1.0.1 // indirect + github.com/swaggo/http-swagger v1.3.4 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/numcpus v0.7.0 // indirect github.com/urfave/cli/v2 v2.25.5 // indirect @@ -91,10 +94,10 @@ require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/lipgloss v0.8.0 - github.com/go-openapi/jsonpointer v0.20.0 // indirect - github.com/go-openapi/jsonreference v0.20.2 // indirect - github.com/go-openapi/spec v0.20.9 // indirect - github.com/go-openapi/swag v0.22.4 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect + github.com/go-openapi/spec v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.8 // indirect diff --git a/go.sum b/go.sum index a3df22e..5ca2c81 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= +github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -71,18 +73,26 @@ github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34 github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/jsonpointer v0.20.0 h1:ESKJdU9ASRfaPNOPRx12IUyA1vn3R9GiE3KYD14BXdQ= github.com/go-openapi/jsonpointer v0.20.0/go.mod h1:6PGzBjjIIumbLYysB73Klnms1mwnU4G3YHOECG3CedA= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= github.com/go-openapi/jsonreference v0.20.0/go.mod h1:Ag74Ico3lPc+zR+qjn4XBUmXymS4zJbYVCZmcgkasdo= github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= github.com/go-openapi/spec v0.20.9 h1:xnlYNQAwKd2VQRRfwTEI0DcK+2cbuvI/0c7jx3gA8/8= github.com/go-openapi/spec v0.20.9/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA= +github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= +github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -242,10 +252,16 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= +github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg= github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw= github.com/swaggo/files/v2 v2.0.0/go.mod h1:24kk2Y9NYEJ5lHuCra6iVwkMjIekMCaFq/0JQj66kyM= +github.com/swaggo/http-swagger v1.3.4 h1:q7t/XLx0n15H1Q9/tk3Y9L4n210XzJF5WtnDX64a5ww= +github.com/swaggo/http-swagger v1.3.4/go.mod h1:9dAh0unqMBAlbp1uE2Uc2mQTxNMU/ha4UbucIg1MFkQ= github.com/swaggo/swag v1.16.1 h1:fTNRhKstPKxcnoKsytm4sahr8FaYzUcT7i1/3nd/fBg= github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKikGxto= +github.com/swaggo/swag v1.16.3 h1:PnCYjPCah8FK4I26l2F/KQ4yz3sILcVUN3cTlBFA9Pg= +github.com/swaggo/swag v1.16.3/go.mod h1:DImHIuOFXKpMFAQjcC7FG4m3Dg4+QuUgUzJmKjI/gRk= github.com/testcontainers/testcontainers-go v0.30.0 h1:jmn/XS22q4YRrcMwWg0pAwlClzs/abopbsBzrepyc4E= github.com/testcontainers/testcontainers-go v0.30.0/go.mod h1:K+kHNGiM5zjklKjgTtcrEetF3uhWbMUyqAQoyoh8Pf0= github.com/testcontainers/testcontainers-go/modules/mysql v0.27.0 h1:6p/o/bAZPcFiBWTd71umQmj/i4L6ipVK3B2ZJBqn5HM= @@ -321,6 +337,7 @@ golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= diff --git a/internal/handlers/dbHandlers.go b/internal/handlers/dbHandlers.go index f8250cc..cdef022 100644 --- a/internal/handlers/dbHandlers.go +++ b/internal/handlers/dbHandlers.go @@ -1,7 +1,9 @@ package handlers import ( - "github.com/gofiber/fiber/v2" + "net/http" + + "github.com/go-chi/chi/v5" "github.com/kareemmahlees/meta-x/internal/db" "github.com/kareemmahlees/meta-x/lib" "github.com/kareemmahlees/meta-x/models" @@ -11,15 +13,15 @@ type DBHandler struct { storage db.DatabaseExecuter } -// TODO: change the interface func NewDBHandler(storage db.DatabaseExecuter) *DBHandler { return &DBHandler{storage} } -func (dh *DBHandler) RegisterRoutes(app *fiber.App) { - dbGroup := app.Group("database") - dbGroup.Get("", dh.handleListDatabases) - dbGroup.Post("", dh.handleCreateDatabase) +func (dh *DBHandler) RegisterRoutes(r *chi.Mux) { + r.Route("/database", func(r chi.Router) { + r.Get("/", dh.handleListDatabases) + r.Post("/", dh.handleCreateDatabase) + }) } // Lists databases @@ -29,30 +31,34 @@ func (dh *DBHandler) RegisterRoutes(app *fiber.App) { // @router /database [get] // @produce json // @success 200 {object} models.ListDatabasesResp -func (dh *DBHandler) handleListDatabases(c *fiber.Ctx) error { +func (dh *DBHandler) handleListDatabases(w http.ResponseWriter, r *http.Request) { dbs, err := dh.storage.ListDBs() if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.JSON(models.ListDatabasesResp{Databases: dbs}) + writeJson(w, models.ListDatabasesResp{Databases: dbs}) } -func (dh *DBHandler) handleCreateDatabase(c *fiber.Ctx) error { +func (dh *DBHandler) handleCreateDatabase(w http.ResponseWriter, r *http.Request) { var payload models.CreatePgMySqlDBPayload - if err := c.BodyParser(&payload); err != nil { - return lib.UnprocessableEntityErr(c, err.Error()) + if err := parseBody(r.Body, &payload); err != nil { + httpError(w, http.StatusUnprocessableEntity, err.Error()) + return } if errs := lib.ValidateStruct(payload); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + httpError(w, http.StatusBadRequest, errs) + return } - err := dh.storage.CreateDB(payload.Name) - - if err != nil { - return lib.InternalServerErr(c, err.Error()) + if err := dh.storage.CreateDB(payload.Name); err != nil { + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.Status(201).JSON(models.SuccessResp{Success: true}) + + w.WriteHeader(http.StatusCreated) + writeJson(w, models.SuccessResp{Success: true}) } diff --git a/internal/handlers/dbHandlers_test.go b/internal/handlers/dbHandlers_test.go index c36bf1f..178b680 100644 --- a/internal/handlers/dbHandlers_test.go +++ b/internal/handlers/dbHandlers_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/gofiber/fiber/v2" + "github.com/go-chi/chi/v5" "github.com/kareemmahlees/meta-x/models" "github.com/kareemmahlees/meta-x/utils" "github.com/stretchr/testify/suite" @@ -27,49 +27,37 @@ func (md *MockDBExecutor) CreateDB(dbName string) error { type DBHandlerTestSuite struct { suite.Suite - app *fiber.App + r *chi.Mux + handler *DBHandler } func (suite *DBHandlerTestSuite) SetupSuite() { - app := fiber.New() + r := chi.NewRouter() storage := NewMockDBExecutor() handler := NewDBHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - suite.app = app + suite.r = r + suite.handler = handler } func (suite *DBHandlerTestSuite) TestRegisterRoutes() { assert := suite.Assert() - var routes []utils.FiberRoute - for _, route := range suite.app.GetRoutes() { - routes = append(routes, utils.FiberRoute{ - Method: route.Method, - Path: route.Path, - }) + var routes []string + for _, route := range suite.r.Routes() { + routes = append(routes, route.Pattern) } - assert.Contains(routes, utils.FiberRoute{ - Method: "GET", - Path: "/database", - }) - + assert.Contains(routes, "/database/*") } func (suite *DBHandlerTestSuite) TestHandleListDatabases() { assert := suite.Assert() - req := httptest.NewRequest("GET", "http://localhost:5522/database", nil) - - resp, _ := suite.app.Test(req) - defer resp.Body.Close() - payload := utils.DecodeBody[models.ListDatabasesResp](resp.Body) - - assert.Equal(resp.StatusCode, fiber.StatusOK) - assert.NotEmpty(payload.Databases) - + assert.HTTPSuccess(suite.handler.handleListDatabases, http.MethodGet, "/database", nil) + assert.HTTPBodyContains(suite.handler.handleListDatabases, http.MethodGet, "/database", nil, "test") } func (suite *DBHandlerTestSuite) TestHandleCreateDatabase() { @@ -80,40 +68,36 @@ func (suite *DBHandlerTestSuite) TestHandleCreateDatabase() { passingBody, _ := utils.EncodeBody(models.CreatePgMySqlDBPayload{ Name: "testing", }) - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/database", - ReqBody: passingBody, - } - - decodedRes, rawRes := passing.RunRequest(suite.app) - assert.Equal(rawRes.StatusCode, fiber.StatusCreated) + req, _ := http.NewRequest(http.MethodPost, "/database", passingBody) + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(suite.handler.handleCreateDatabase) + handler.ServeHTTP(rr, req) + assert.Equal(rr.Code, http.StatusCreated) + + decodedRes := utils.DecodeBody[models.SuccessResp](rr.Result().Body) assert.True(decodedRes.Success) + }) t.Run("should fail unproccessable entity", func(t *testing.T) { - failingUnprocessableEntity := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/database", - } - decodedRes, rawRes := failingUnprocessableEntity.RunRequest(suite.app) - assert.Equal(http.StatusUnprocessableEntity, rawRes.StatusCode) - assert.Contains(decodedRes.Message, "Unprocessable Entity") + assert.HTTPError(suite.handler.handleCreateDatabase, http.MethodPost, "/database", nil) + }) + t.Run("should fail bad request", func(t *testing.T) { failingBadRequestBody, _ := utils.EncodeBody(models.CreatePgMySqlDBPayload{ Name: "", }) - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/database", - ReqBody: failingBadRequestBody, - } - decodedRes, rawRes = failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + req, _ := http.NewRequest(http.MethodPost, "/database", failingBadRequestBody) + rr := httptest.NewRecorder() - }) + handler := http.HandlerFunc(suite.handler.handleCreateDatabase) + handler.ServeHTTP(rr, req) + assert.Equal(rr.Code, http.StatusBadRequest) + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.Len(decodedRes.Message, 1) + }) } func TestDBHandlerTestSuite(t *testing.T) { diff --git a/internal/handlers/defaultHandler.go b/internal/handlers/defaultHandler.go index f5f4fe0..40994ef 100644 --- a/internal/handlers/defaultHandler.go +++ b/internal/handlers/defaultHandler.go @@ -1,23 +1,21 @@ package handlers import ( + "net/http" "time" - "github.com/gofiber/fiber/v2" - "github.com/kareemmahlees/meta-x/internal/db" + "github.com/go-chi/chi/v5" ) -type DefaultHandler struct { - storage *db.Storage -} +type DefaultHandler struct{} -func NewDefaultHandler(storage *db.Storage) *DefaultHandler { - return &DefaultHandler{storage} +func NewDefaultHandler() *DefaultHandler { + return &DefaultHandler{} } -func (h *DefaultHandler) RegisterRoutes(app *fiber.App) { - app.Get("/health", h.healthCheck) - app.Get("/", h.apiInfo) +func (h *DefaultHandler) RegisterRoutes(r *chi.Mux) { + r.Get("/health", h.healthCheck) + r.Get("/", h.apiInfo) } type HealthCheckResult struct { @@ -31,15 +29,17 @@ type HealthCheckResult struct { // @tags default // @router /health [get] // @success 200 {object} HealthCheckResult -func (h *DefaultHandler) healthCheck(c *fiber.Ctx) error { - return c.JSON(fiber.Map{"date": time.Now()}) +func (h *DefaultHandler) healthCheck(w http.ResponseWriter, r *http.Request) { + writeJson(w, map[string]time.Time{ + "date": time.Now(), + }) } type APIInfoResult struct { - Author string - Year int - Contact string - Repo string + Author string `json:"author"` + Year int `json:"yeaer"` + Contact string `json:"contact"` + Repo string `json:"repo"` } // Get info about the api @@ -49,11 +49,11 @@ type APIInfoResult struct { // @tags default // @router / [get] // @success 200 {object} APIInfoResult -func (h *DefaultHandler) apiInfo(c *fiber.Ctx) error { - return c.JSON(fiber.Map{ - "author": "Kareem Ebrahim", - "year": 2023, - "contact": "kareemmahlees@gmail.com", - "repo": "https://github.com/kareemmahlees/meta-x", +func (h *DefaultHandler) apiInfo(w http.ResponseWriter, r *http.Request) { + writeJson(w, APIInfoResult{ + Author: "Kareem Ebrahim", + Year: 2024, + Contact: "kareemmahlees@gmail.com", + Repo: "https://github.com/kareemmahlees/meta-x", }) } diff --git a/internal/handlers/defaultHandler_test.go b/internal/handlers/defaultHandler_test.go index 53f1750..e9abab3 100644 --- a/internal/handlers/defaultHandler_test.go +++ b/internal/handlers/defaultHandler_test.go @@ -1,62 +1,53 @@ package handlers import ( - "net/http/httptest" + "net/http" "testing" - "github.com/gofiber/fiber/v2" - "github.com/kareemmahlees/meta-x/utils" - "github.com/stretchr/testify/assert" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/suite" ) -func TestRegisterDefaultRoutes(t *testing.T) { - - app := fiber.New() - handler := NewDefaultHandler(nil) - - handler.RegisterRoutes(app) - - var routes []utils.FiberRoute - for _, route := range app.GetRoutes() { - routes = append(routes, utils.FiberRoute{ - Method: route.Method, - Path: route.Path, - }) - } +type DefaultHandlerTestSuite struct { + suite.Suite + r *chi.Mux + handler *DefaultHandler +} - assert.Contains(t, routes, utils.FiberRoute{ - Method: "GET", - Path: "/health", - }) +func (suite *DefaultHandlerTestSuite) SetupSuite() { + r := chi.NewRouter() + handler := NewDefaultHandler() + handler.RegisterRoutes(r) + suite.r = r + suite.handler = handler } -func TestHealthCheck(t *testing.T) { - app := fiber.New() - handler := NewDefaultHandler(nil) - handler.RegisterRoutes(app) +func (suite *DefaultHandlerTestSuite) TestRegisterDefaultRoutes() { + assert := suite.Assert() - req := httptest.NewRequest("GET", "http://localhost:4000/health", nil) + var routes []string + for _, route := range suite.r.Routes() { + routes = append(routes, route.Pattern) + } - resp, _ := app.Test(req) - payload := utils.DecodeBody[map[string]any](resp.Body) + assert.Contains(routes, "/health") + assert.Contains(routes, "/") +} - assert.Equal(t, resp.StatusCode, fiber.StatusOK) +func (suite *DefaultHandlerTestSuite) TestHealthCheck() { + assert := suite.Assert() - _, ok := payload["date"] - assert.True(t, ok) + assert.HTTPSuccess(suite.handler.healthCheck, http.MethodGet, "/health", nil) + assert.HTTPBodyContains(suite.handler.healthCheck, http.MethodGet, "/health", nil, "date") } -func TestBaseUrl(t *testing.T) { - app := fiber.New() - handler := NewDefaultHandler(nil) - handler.RegisterRoutes(app) +func (suite *DefaultHandlerTestSuite) TestAPIInfo() { + assert := suite.Assert() - req := httptest.NewRequest("GET", "http://localhost:4000", nil) - - resp, err := app.Test(req) - assert.Nil(t, err) - _ = utils.DecodeBody[map[string]any](resp.Body) + assert.HTTPSuccess(suite.handler.apiInfo, http.MethodGet, "/", nil) +} - assert.Equal(t, resp.StatusCode, fiber.StatusOK) +func TestDefaultHandlerTestSuite(t *testing.T) { + suite.Run(t, new(DefaultHandlerTestSuite)) } diff --git a/internal/handlers/encoding.go b/internal/handlers/encoding.go new file mode 100644 index 0000000..3854c74 --- /dev/null +++ b/internal/handlers/encoding.go @@ -0,0 +1,24 @@ +package handlers + +import ( + "encoding/json" + "io" + "net/http" +) + +func writeJson[T any](w http.ResponseWriter, payload T) { + w.Header().Set("Content-Type", "application/json") + + if err := json.NewEncoder(w).Encode(payload); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func parseBody[T any](body io.Reader, parsedResult T) error { + + if err := json.NewDecoder(body).Decode(parsedResult); err != nil { + return err + } + + return nil +} diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go new file mode 100644 index 0000000..f26a29d --- /dev/null +++ b/internal/handlers/errors.go @@ -0,0 +1,16 @@ +package handlers + +import ( + "net/http" + + "github.com/kareemmahlees/meta-x/models" +) + +func httpError(w http.ResponseWriter, code int, errMsg any) { + w.WriteHeader(code) + + writeJson(w, models.ErrResp{ + Code: code, + Message: errMsg, + }) +} diff --git a/internal/handlers/errors_test.go b/internal/handlers/errors_test.go new file mode 100644 index 0000000..90d89be --- /dev/null +++ b/internal/handlers/errors_test.go @@ -0,0 +1,21 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/kareemmahlees/meta-x/models" + "github.com/kareemmahlees/meta-x/utils" + "github.com/stretchr/testify/assert" +) + +func TestHttpError(t *testing.T) { + rr := httptest.NewRecorder() + httpError(rr, http.StatusOK, "something") + + assert.Equal(t, rr.Code, http.StatusOK) + + body := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.Equal(t, body.Message, "something") +} diff --git a/internal/handlers/tableHandlers.go b/internal/handlers/tableHandlers.go index dff5e7a..8b71e92 100644 --- a/internal/handlers/tableHandlers.go +++ b/internal/handlers/tableHandlers.go @@ -1,7 +1,9 @@ package handlers import ( - "github.com/gofiber/fiber/v2" + "net/http" + + "github.com/go-chi/chi/v5" "github.com/kareemmahlees/meta-x/internal/db" "github.com/kareemmahlees/meta-x/lib" "github.com/kareemmahlees/meta-x/models" @@ -15,15 +17,16 @@ func NewTableHandler(storage db.TableExecuter) *TableHandler { return &TableHandler{storage} } -func (th *TableHandler) RegisterRoutes(app *fiber.App) { - tableGroup := app.Group("table") - tableGroup.Get("", th.handleListTables) - tableGroup.Get("/:tableName/describe", th.handleGetTableInfo) - tableGroup.Post("/:tableName", th.handleCreateTable) - tableGroup.Delete("/:tableName", th.handleDeleteTable) - tableGroup.Post("/:tableName/column/add", th.handleAddColumn) - tableGroup.Put("/:tableName/column/modify", th.handleModifyColumn) - tableGroup.Delete("/:tableName/column/delete", th.handleDeleteColumn) +func (h *TableHandler) RegisterRoutes(r *chi.Mux) { + r.Route("/table", func(r chi.Router) { + r.Get("/", h.handleListTables) + r.Get("/{tableName}/describe", h.handleGetTableInfo) + r.Post("/{tableName}", h.handleCreateTable) + r.Delete("/{tableName}", h.handleDeleteTable) + r.Post("/{tableName}/column/add", h.handleAddColumn) + r.Put("/{tableName}/column/modify", h.handleModifyColumn) + r.Delete("/{tableName}/column/delete", h.handleDeleteColumn) + }) } // Get detailed info about the specified table @@ -33,20 +36,23 @@ func (th *TableHandler) RegisterRoutes(app *fiber.App) { // @router /table/{tableName}/describe [get] // @produce json // @success 200 {object} []models.TableInfoResp -func (th *TableHandler) handleGetTableInfo(c *fiber.Ctx) error { +func (h *TableHandler) handleGetTableInfo(w http.ResponseWriter, r *http.Request) { params := struct { - TableName string `params:"tableName" validate:"required,alpha"` - }{} - _ = c.ParamsParser(¶ms) + TableName string `validate:"required,alpha"` + }{ + TableName: chi.URLParam(r, "tableName"), + } - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } - tableInfo, err := th.storage.GetTable(params.TableName) + tableInfo, err := h.storage.GetTable(params.TableName) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.JSON(tableInfo) + writeJson(w, tableInfo) } // Lists all tables in the database @@ -56,12 +62,13 @@ func (th *TableHandler) handleGetTableInfo(c *fiber.Ctx) error { // @router /table [get] // @produce json // @success 200 {object} models.ListTablesResp -func (th *TableHandler) handleListTables(c *fiber.Ctx) error { - tables, err := th.storage.ListTables() +func (h *TableHandler) handleListTables(w http.ResponseWriter, r *http.Request) { + tables, err := h.storage.ListTables() if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.JSON(models.ListTablesResp{Tables: tables}) + writeJson(w, models.ListTablesResp{Tables: tables}) } // Creates a Table @@ -74,29 +81,34 @@ func (th *TableHandler) handleListTables(c *fiber.Ctx) error { // @accept json // @produce json // @success 201 {object} models.CreateTableResp -func (th *TableHandler) handleCreateTable(c *fiber.Ctx) error { +func (h *TableHandler) handleCreateTable(w http.ResponseWriter, r *http.Request) { params := struct { - TableName string `params:"tableName" validate:"required,alphanum"` - }{} - _ = c.ParamsParser(¶ms) - - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + TableName string `validate:"required,alphanum"` + }{ + TableName: chi.URLParam(r, "tableName"), + } + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } var payload []models.CreateTablePayload - if err := c.BodyParser(&payload); err != nil { - return lib.UnprocessableEntityErr(c, err.Error()) + if err := parseBody(r.Body, &payload); err != nil { + httpError(w, http.StatusUnprocessableEntity, err.Error()) + return } for _, v := range payload { if errs := lib.ValidateStruct(v); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + httpError(w, http.StatusBadRequest, errs) + return } } - err := th.storage.CreateTable(params.TableName, payload) + err := h.storage.CreateTable(params.TableName, payload) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) } - return c.Status(fiber.StatusCreated).JSON(models.CreateTableResp{Created: params.TableName}) + + w.WriteHeader(http.StatusCreated) + writeJson(w, models.CreateTableResp{Created: params.TableName}) } // Updates a table by adding a column @@ -109,26 +121,33 @@ func (th *TableHandler) handleCreateTable(c *fiber.Ctx) error { // @accept json // @produce json // @success 201 {object} models.SuccessResp -func (th *TableHandler) handleAddColumn(c *fiber.Ctx) error { +func (h *TableHandler) handleAddColumn(w http.ResponseWriter, r *http.Request) { params := struct { - TableName string `params:"tableName" validate:"required,alphanum"` - }{} - _ = c.ParamsParser(¶ms) - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + TableName string `validate:"required,alphanum"` + }{ + TableName: chi.URLParam(r, "tableName"), + } + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } var payload models.AddModifyColumnPayload - if err := c.BodyParser(&payload); err != nil { - return lib.UnprocessableEntityErr(c, err.Error()) + if err := parseBody(r.Body, &payload); err != nil { + httpError(w, http.StatusUnprocessableEntity, err.Error()) + return } - if errs := lib.ValidateStruct(payload); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + if errs := lib.ValidateStruct(&payload); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } - err := th.storage.AddColumn(c.Params("tableName"), payload) + err := h.storage.AddColumn(params.TableName, payload) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.Status(fiber.StatusCreated).JSON(models.SuccessResp{Success: true}) + + w.WriteHeader(http.StatusCreated) + writeJson(w, models.SuccessResp{Success: true}) } // Updates a table by modifying a column @@ -141,30 +160,32 @@ func (th *TableHandler) handleAddColumn(c *fiber.Ctx) error { // @accept json // @produce json // @success 200 {object} models.SuccessResp -func (th *TableHandler) handleModifyColumn(c *fiber.Ctx) error { - if c.Locals("provider") == lib.SQLITE3 { - return lib.ForbiddenErr(c, "MODIFY COLUMN not supported by sqlite") - } +func (h *TableHandler) handleModifyColumn(w http.ResponseWriter, r *http.Request) { params := struct { - TableName string `params:"tableName" validate:"required,alphanum"` - }{} + TableName string `validate:"required,alphanum"` + }{ + TableName: chi.URLParam(r, "tableName"), + } - _ = c.ParamsParser(¶ms) - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } var payload models.AddModifyColumnPayload - if err := c.BodyParser(&payload); err != nil { - return lib.UnprocessableEntityErr(c, err.Error()) + if err := parseBody(r.Body, &payload); err != nil { + httpError(w, http.StatusUnprocessableEntity, err.Error()) + return } - if errs := lib.ValidateStruct(payload); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + if errs := lib.ValidateStruct(&payload); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } - err := th.storage.UpdateColumn(c.Params("tableName"), payload) + err := h.storage.UpdateColumn(params.TableName, payload) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.JSON(models.SuccessResp{Success: true}) + writeJson(w, models.SuccessResp{Success: true}) } // Updates a table by deleting/dropping a column @@ -177,27 +198,32 @@ func (th *TableHandler) handleModifyColumn(c *fiber.Ctx) error { // @accept json // @produce json // @success 200 {object} models.SuccessResp -func (th *TableHandler) handleDeleteColumn(c *fiber.Ctx) error { +func (h *TableHandler) handleDeleteColumn(w http.ResponseWriter, r *http.Request) { params := struct { TableName string `params:"tableName" validate:"required,alphanum"` - }{} - _ = c.ParamsParser(¶ms) - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + }{ + TableName: chi.URLParam(r, "tableName"), + } + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } var payload models.DeleteColumnPayload - if err := c.BodyParser(&payload); err != nil { - return lib.UnprocessableEntityErr(c, err.Error()) + if err := parseBody(r.Body, &payload); err != nil { + httpError(w, http.StatusUnprocessableEntity, err.Error()) + return } - if errs := lib.ValidateStruct(payload); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + if errs := lib.ValidateStruct(&payload); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } - err := th.storage.DeleteColumn(params.TableName, payload) + err := h.storage.DeleteColumn(params.TableName, payload) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err) + return } - return c.JSON(models.SuccessResp{Success: true}) + writeJson(w, models.SuccessResp{Success: true}) } // Deletes a table @@ -209,18 +235,21 @@ func (th *TableHandler) handleDeleteColumn(c *fiber.Ctx) error { // @accept json // @produce json // @success 200 {object} models.SuccessResp -func (th *TableHandler) handleDeleteTable(c *fiber.Ctx) error { +func (h *TableHandler) handleDeleteTable(w http.ResponseWriter, r *http.Request) { params := struct { TableName string `params:"tableName" validate:"required,alpha"` - }{} - _ = c.ParamsParser(¶ms) - if errs := lib.ValidateStruct(params); len(errs) > 0 { - return lib.BadRequestErr(c, errs) + }{ + TableName: chi.URLParam(r, "tableName"), + } + if errs := lib.ValidateStruct(¶ms); len(errs) > 0 { + httpError(w, http.StatusBadRequest, errs) + return } - err := th.storage.DeleteTable(c.Params("tableName")) + err := h.storage.DeleteTable(params.TableName) if err != nil { - return lib.InternalServerErr(c, err.Error()) + httpError(w, http.StatusInternalServerError, err.Error()) + return } - return c.JSON(models.SuccessResp{Success: true}) + writeJson(w, models.SuccessResp{Success: true}) } diff --git a/internal/handlers/tableHandlers_test.go b/internal/handlers/tableHandlers_test.go index 0da485b..fc28c27 100644 --- a/internal/handlers/tableHandlers_test.go +++ b/internal/handlers/tableHandlers_test.go @@ -5,7 +5,7 @@ import ( "net/http" "testing" - "github.com/gofiber/fiber/v2" + "github.com/go-chi/chi/v5" "github.com/kareemmahlees/meta-x/models" "github.com/kareemmahlees/meta-x/utils" "github.com/stretchr/testify/suite" @@ -13,35 +13,30 @@ import ( type TableHandlerTestSuite struct { suite.Suite - app *fiber.App + r *chi.Mux + handler *TableHandler } func (suite *TableHandlerTestSuite) SetupSuite() { - app := fiber.New() + r := chi.NewRouter() storage := NewMockTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - suite.app = app + suite.r = r + suite.handler = handler } func (suite *TableHandlerTestSuite) TestRegisterRoutes() { assert := suite.Assert() - var routes []utils.FiberRoute - for _, route := range suite.app.GetRoutes() { - routes = append(routes, utils.FiberRoute{ - Method: route.Method, - Path: route.Path, - }) + var routes []string + for _, route := range suite.r.Routes() { + routes = append(routes, route.Pattern) } - assert.Contains(routes, utils.FiberRoute{ - Method: "GET", - Path: "/table", - }) - + assert.Contains(routes, "/table/*") } func (suite *TableHandlerTestSuite) TestHandleGetTableInfo() { @@ -49,40 +44,28 @@ func (suite *TableHandlerTestSuite) TestHandleGetTableInfo() { t := suite.T() t.Run("should pass", func(t *testing.T) { - passing := utils.RequestTesting[[]models.TableInfoResp]{ - ReqMethod: http.MethodGet, - ReqUrl: "/table/test/describe", - } - - tableInfo, _ := passing.RunRequest(suite.app) - assert.NotEmpty(tableInfo) - assert.Equal(tableInfo[0].Name, "name") + rr := utils.TestRequest(suite.r, http.MethodGet, "/table/test/describe", http.NoBody) + assert.Equal(rr.Code, http.StatusOK) }) t.Run("should fail bad request", func(t *testing.T) { - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodGet, - ReqUrl: "/table/12345/describe", - } - decoedResp, rawResp := failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawResp.StatusCode) - assert.Len(decoedResp.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodGet, "/table/12345/describe", http.NoBody) + assert.Equal(rr.Code, http.StatusBadRequest) + + decodedResp := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.NotEmpty(decodedResp.Message) }) t.Run("should fail internal server", func(t *testing.T) { - failingInternalServer := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodGet, - ReqUrl: "/table/test/describe", - } - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - _, rawResp := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawResp.StatusCode) + rr := utils.TestRequest(r, http.MethodGet, "/table/test/describe", http.NoBody) + assert.Equal(rr.Code, http.StatusInternalServerError) }) } @@ -91,32 +74,20 @@ func (suite *TableHandlerTestSuite) TestHandleListTables() { t := suite.T() t.Run("should pass", func(t *testing.T) { - passing := utils.RequestTesting[models.ListTablesResp]{ - ReqMethod: http.MethodGet, - ReqUrl: "/table", - } - decoedRes, _ := passing.RunRequest(suite.app) - - tables := utils.SliceOfPointersToSliceOfValues(decoedRes.Tables) - assert.NotEmpty(tables) - assert.Contains(tables, "test") - + assert.HTTPSuccess(suite.handler.handleListTables, http.MethodGet, "/table", nil) + assert.HTTPBodyContains(suite.handler.handleListTables, http.MethodGet, "/table", nil, "test") }) t.Run("should fail internal server", func(t *testing.T) { - failingInternalServer := utils.RequestTesting[models.ListTablesResp]{ - ReqMethod: http.MethodGet, - ReqUrl: "/table", - } - - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - _, rawRes := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawRes.StatusCode) + rr := utils.TestRequest(r, http.MethodGet, "/table", http.NoBody) + + assert.Equal(http.StatusInternalServerError, rr.Code) }) } @@ -131,36 +102,25 @@ func (suite *TableHandlerTestSuite) TestHandleCreateTable() { Default: "kareem", Unique: true, }}) - passing := utils.RequestTesting[models.CreateTableResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test1", - ReqBody: passingBody, - } - decodedResp, rawResp := passing.RunRequest(suite.app) - assert.Equal(http.StatusCreated, rawResp.StatusCode) - assert.Equal(decodedResp.Created, "test1") + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/test1", passingBody) + assert.Equal(http.StatusCreated, rr.Code) + decodedRes := utils.DecodeBody[models.CreateTableResp](rr.Result().Body) + assert.Equal(decodedRes.Created, "test1") }) t.Run("should fail unprocessable entitiy", func(t *testing.T) { - failingUnprocessableEntitiy := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/anything", - } - decodedResp, rawResp := failingUnprocessableEntitiy.RunRequest(suite.app) - assert.Equal(http.StatusUnprocessableEntity, rawResp.StatusCode) - assert.Contains(decodedResp.Message, "Unprocessable Entity") + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/anything", http.NoBody) + assert.Equal(http.StatusUnprocessableEntity, rr.Code) }) - t.Run("should fail bad request", func(t *testing.T) { - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/1.1", - } - decodedResp, rawResp := failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawResp.StatusCode) - assert.NotZero(decodedResp.Message) + t.Run("should fail bad request params", func(t *testing.T) { + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/1.1", http.NoBody) + assert.Equal(http.StatusBadRequest, rr.Code) + }) + + t.Run("should fail bad request body", func(t *testing.T) { failingBadRequestBody, _ := utils.EncodeBody([]models.CreateTablePayload{{ ColName: "test2", Type: "varchar(255)", @@ -168,15 +128,8 @@ func (suite *TableHandlerTestSuite) TestHandleCreateTable() { Default: nil, Unique: nil, }}) - failingBadRequest = utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/anything", - ReqBody: failingBadRequestBody, - } - decodedResp, rawResp = failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawResp.StatusCode) - assert.NotZero(decodedResp.Message) - + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/anything", failingBadRequestBody) + assert.Equal(http.StatusBadRequest, rr.Code) }) t.Run("should fail internal server", func(t *testing.T) { @@ -186,20 +139,15 @@ func (suite *TableHandlerTestSuite) TestHandleCreateTable() { Default: "kareem", Unique: true, }}) - failingInternalServer := utils.RequestTesting[models.CreateTableResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test1", - ReqBody: failingInternalServerBody, - } - - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - _, rawResp := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawResp.StatusCode) + rr := utils.TestRequest(r, http.MethodPost, "/table/test1", failingInternalServerBody) + + assert.Equal(http.StatusInternalServerError, rr.Code) }) } @@ -210,66 +158,41 @@ func (suite *TableHandlerTestSuite) TestHandleAddColumn() { t.Run("should pass", func(t *testing.T) { passingBody, _ := utils.EncodeBody(models.AddModifyColumnPayload{ColName: "test3", Type: "varchar(255)"}) - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test/column/add", - ReqBody: passingBody, - } - decoedBody, _ := passing.RunRequest(suite.app) + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/test/column/add", passingBody) + decoedBody := utils.DecodeBody[models.SuccessResp](rr.Result().Body) assert.True(decoedBody.Success) - }) t.Run("should fail unproccessable entity", func(t *testing.T) { - failingUnprocessableEntitiy := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test/column/add", - } - decodedResp, rawResp := failingUnprocessableEntitiy.RunRequest(suite.app) - assert.Equal(http.StatusUnprocessableEntity, rawResp.StatusCode) - assert.Contains(decodedResp.Message, "Unprocessable Entity") - - failingBadRequestParam := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/1.1/column/add", - } - decodedRes, rawRes := failingBadRequestParam.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/test/column/add", http.NoBody) + assert.Equal(http.StatusUnprocessableEntity, rr.Code) }) - t.Run("should fail bad request", func(t *testing.T) { + t.Run("should fail bad request param", func(t *testing.T) { + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/1.1/column/add", http.NoBody) + assert.Equal(http.StatusBadRequest, rr.Code) + }) + + t.Run("should fail bad request body", func(t *testing.T) { failingBadRequestBody, _ := utils.EncodeBody(models.AddModifyColumnPayload{ ColName: "", Type: "varchar(255)", }) - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test/column/add", - ReqBody: failingBadRequestBody, - } - decodedResp, rawResp := failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawResp.StatusCode) - assert.Len(decodedResp.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodPost, "/table/test/column/add", failingBadRequestBody) + assert.Equal(http.StatusBadRequest, rr.Code) }) t.Run("should fail internal server", func(t *testing.T) { failingInternalServerBody, _ := utils.EncodeBody(models.AddModifyColumnPayload{ColName: "test3", Type: "varchar(255)"}) - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodPost, - ReqUrl: "/table/test/column/add", - ReqBody: failingInternalServerBody, - } - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() - handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - _, rawRes := passing.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawRes.StatusCode) + rr := utils.TestRequest(r, http.MethodPost, "/table/test/column/add", failingInternalServerBody) + assert.Equal(http.StatusInternalServerError, rr.Code) }) } @@ -279,33 +202,26 @@ func (suite *TableHandlerTestSuite) TestHandleUpdateColumn() { t.Run("should pass", func(t *testing.T) { passingBody, _ := utils.EncodeBody(models.AddModifyColumnPayload{ColName: "name", Type: "varchar(255)"}) - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodPut, - ReqUrl: "/table/test/column/modify", - ReqBody: passingBody, - } - decodedRes, _ := passing.RunRequest(suite.app) + rr := utils.TestRequest(suite.r, http.MethodPut, "/table/test/column/modify", passingBody) + decodedRes := utils.DecodeBody[models.SuccessResp](rr.Result().Body) + assert.Equal(rr.Code, http.StatusOK) assert.True(decodedRes.Success) }) t.Run("should fail unproccessable entity", func(t *testing.T) { - failingUnprocessableEntity := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPut, - ReqUrl: "/table/test/column/modify", - } - decodedRes, rawRes := failingUnprocessableEntity.RunRequest(suite.app) - assert.Equal(http.StatusUnprocessableEntity, rawRes.StatusCode) - assert.Contains(decodedRes.Message, "Unprocessable Entity") + rr := utils.TestRequest(suite.r, http.MethodPut, "/table/test/column/modify", http.NoBody) + assert.Equal(http.StatusUnprocessableEntity, rr.Code) + + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail bad request param", func(t *testing.T) { - failingBadRequestParam := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPut, - ReqUrl: "/table/1.1/column/modify", - } - decodedRes, rawRes := failingBadRequestParam.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodPut, "/table/1.1/column/modify", http.NoBody) + assert.Equal(http.StatusBadRequest, rr.Code) + + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail bad request body", func(t *testing.T) { @@ -313,31 +229,24 @@ func (suite *TableHandlerTestSuite) TestHandleUpdateColumn() { ColName: "", Type: "varchar(255)", }) - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodPut, - ReqUrl: "/table/test/column/modify", - ReqBody: failingBadRequestBody, - } - decodedRes, rawRes := failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodPut, "/table/test/column/modify", failingBadRequestBody) + assert.Equal(http.StatusBadRequest, rr.Code) + + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail internal server", func(t *testing.T) { failingInternalServerBody, _ := utils.EncodeBody(models.AddModifyColumnPayload{ColName: "name", Type: "varchar(255)"}) - failingInternalServer := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodPut, - ReqUrl: "/table/test/column/modify", - ReqBody: failingInternalServerBody, - } - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) - _, rawRes := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawRes.StatusCode) + rr := utils.TestRequest(r, http.MethodPut, "/table/test/column/modify", failingInternalServerBody) + + assert.Equal(http.StatusInternalServerError, rr.Code) }) } @@ -347,65 +256,49 @@ func (suite *TableHandlerTestSuite) TestHandleDeleteColumn() { t.Run("should pass", func(t *testing.T) { passingBody, _ := utils.EncodeBody(models.DeleteColumnPayload{ColName: "name"}) - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test/column/delete", - ReqBody: passingBody, - } - - decoedRes, _ := passing.RunRequest(suite.app) + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/test/column/delete", passingBody) - assert.True(decoedRes.Success) + decodedRes := utils.DecodeBody[models.SuccessResp](rr.Result().Body) + assert.True(decodedRes.Success) }) t.Run("should fail bad request param", func(t *testing.T) { - failingBadRequestParam := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/1.1/column/delete", - } - decodedRes, rawRes := failingBadRequestParam.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/1.1/column/delete", http.NoBody) + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + + assert.Equal(http.StatusBadRequest, rr.Code) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail bad request", func(t *testing.T) { failingBadRequestBody, _ := utils.EncodeBody(models.DeleteColumnPayload{ ColName: "", }) - failingBadRequest := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test/column/delete", - ReqBody: failingBadRequestBody, - } - decodedRes, rawRes := failingBadRequest.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/test/column/delete", failingBadRequestBody) + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + + assert.Equal(http.StatusBadRequest, rr.Code) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail unproccessable entity", func(t *testing.T) { - failingUnprocessableEntity := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test/column/delete", - } - decodedRes, rawRes := failingUnprocessableEntity.RunRequest(suite.app) - assert.Equal(http.StatusUnprocessableEntity, rawRes.StatusCode) - assert.Contains(decodedRes.Message, "Unprocessable Entity") + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/test/column/delete", http.NoBody) + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + + assert.Equal(http.StatusUnprocessableEntity, rr.Code) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail internal server", func(t *testing.T) { failingInternalServerBody, _ := utils.EncodeBody(models.DeleteColumnPayload{ColName: "name"}) - failingInternalServer := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test/column/delete", - ReqBody: failingInternalServerBody, - } - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) + + rr := utils.TestRequest(r, http.MethodDelete, "/table/test/column/delete", failingInternalServerBody) - _, rawRes := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawRes.StatusCode) + assert.Equal(http.StatusInternalServerError, rr.Code) }) } @@ -414,37 +307,29 @@ func (suite *TableHandlerTestSuite) TestHandleDeleteTable() { t := suite.T() t.Run("should pass", func(t *testing.T) { - passing := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test", - } - decodedRes, _ := passing.RunRequest(suite.app) + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/test", http.NoBody) + decodedRes := utils.DecodeBody[models.SuccessResp](rr.Result().Body) + assert.True(decodedRes.Success) }) t.Run("should fail bad request param", func(t *testing.T) { - failingBadRequestParams := utils.RequestTesting[models.ErrResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/1.1", - } - decodedRes, rawRes := failingBadRequestParams.RunRequest(suite.app) - assert.Equal(http.StatusBadRequest, rawRes.StatusCode) - assert.Len(decodedRes.Message, 1) + rr := utils.TestRequest(suite.r, http.MethodDelete, "/table/1.1", http.NoBody) + decodedRes := utils.DecodeBody[models.ErrResp](rr.Result().Body) + + assert.Equal(http.StatusBadRequest, rr.Code) + assert.NotEmpty(decodedRes.Message) }) t.Run("should fail internal server", func(t *testing.T) { - failingInternalServer := utils.RequestTesting[models.SuccessResp]{ - ReqMethod: http.MethodDelete, - ReqUrl: "/table/test", - } - - app := fiber.New() + r := chi.NewRouter() storage := NewFaultyTableExecutor() handler := NewTableHandler(storage) - handler.RegisterRoutes(app) + handler.RegisterRoutes(r) + + rr := utils.TestRequest(r, http.MethodDelete, "/table/test", http.NoBody) - _, rawRes := failingInternalServer.RunRequest(app) - assert.Equal(http.StatusInternalServerError, rawRes.StatusCode) + assert.Equal(http.StatusInternalServerError, rr.Code) }) } diff --git a/internal/server.go b/internal/server.go index 87a18eb..a97d666 100644 --- a/internal/server.go +++ b/internal/server.go @@ -2,73 +2,56 @@ package internal import ( "fmt" + "log/slog" + "net/http" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/playground" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/kareemmahlees/meta-x/internal/db" "github.com/kareemmahlees/meta-x/internal/graph" "github.com/kareemmahlees/meta-x/internal/handlers" - "github.com/kareemmahlees/meta-x/utils" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/swagger" + httpSwagger "github.com/swaggo/http-swagger" ) type Server struct { storage db.Storage port int + router *chi.Mux listenCh chan bool - app *fiber.App } func NewServer(storage db.Storage, port int, listenCh chan bool) *Server { - return &Server{storage, port, listenCh, nil} + r := chi.NewRouter() + return &Server{storage, port, r, listenCh} } func (s *Server) Serve() error { - // see https://github.com/99designs/gqlgen/issues/1664#issuecomment-1616620967 - // Create a gqlgen handler h := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{Storage: s.storage}})) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) - s.app = app - - app.All("/graphql", func(c *fiber.Ctx) error { - utils.GraphQLHandler(h.ServeHTTP)(c) - return nil - }).Name("graphql") - - app.All("/playground", func(c *fiber.Ctx) error { - utils.GraphQLHandler(playground.Handler("GraphQL", "/graphql"))(c) - return nil - }).Name("playground") + s.router.Use(middleware.Logger) + s.router.Post("/graphql", h.ServeHTTP) + s.router.Get("/playground", playground.ApolloSandboxHandler("GraphQL", "/graphql")) + s.router.Get("/swagger/*", httpSwagger.Handler( + httpSwagger.URL(fmt.Sprintf("http://localhost:%d/swagger/doc.json", s.port)), + )) - app.Get("/swagger/*", swagger.HandlerDefault).Name("swagger") - app.Use(logger.New()) - - defaultHandler := handlers.NewDefaultHandler(nil) + defaultHandler := handlers.NewDefaultHandler() dbHandler := handlers.NewDBHandler(s.storage) tableHandler := handlers.NewTableHandler(s.storage) - defaultHandler.RegisterRoutes(app) - dbHandler.RegisterRoutes(app) - tableHandler.RegisterRoutes(app) + defaultHandler.RegisterRoutes(s.router) + dbHandler.RegisterRoutes(s.router) + tableHandler.RegisterRoutes(s.router) - app.Hooks().OnListen(func(ld fiber.ListenData) error { - s.listenCh <- true - fmt.Println(utils.NewStyle("REST", "#4B87FF"), fmt.Sprintf("http://localhost:%d", s.port)) - fmt.Println(utils.NewStyle("Swagger", "#0EEBA1"), fmt.Sprintf("http://localhost:%d/swagger", s.port)) - fmt.Println(utils.NewStyle("GraphQl", "#FF70FD"), fmt.Sprintf("http://localhost:%d/graphql", s.port)) - fmt.Println(utils.NewStyle("Playground", "#B6B5B5"), fmt.Sprintf("http://localhost:%d/playground\n", s.port)) + slog.Info("Server started listening", "port", s.port) - return nil - }) + s.listenCh <- true - if err := app.Listen(fmt.Sprintf(":%d", s.port)); err != nil { + if err := http.ListenAndServe(fmt.Sprintf(":%d", s.port), s.router); err != nil { s.listenCh <- false return err } return nil - } diff --git a/internal/server_test.go b/internal/server_test.go index d47d13a..5d0cba4 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -1,13 +1,10 @@ package internal import ( - "fmt" "log" - "net/http" "testing" "github.com/kareemmahlees/meta-x/models" - "github.com/kareemmahlees/meta-x/utils" "github.com/stretchr/testify/assert" ) @@ -47,6 +44,7 @@ func (ms *MockStorage) DeleteColumn(tableName string, data models.DeleteColumnPa func TestServe(t *testing.T) { listenCh := make(chan bool, 1) + server := NewServer(NewMockStorage(), 5522, listenCh) go func() { @@ -57,19 +55,15 @@ func TestServe(t *testing.T) { assert.True(t, <-listenCh) - testRoutes := []string{"graphql", "playground", "swagger"} - - for _, route := range testRoutes { - foundRoute := server.app.GetRoute(route) - assert.NotEmpty(t, foundRoute) - - request := utils.RequestTesting[any]{ - ReqMethod: http.MethodGet, - ReqUrl: fmt.Sprintf("/%s", route), - } - _, res := request.RunRequest(server.app) + testRoutes := []string{"/graphql", "/playground", "/swagger/*"} + registerdRoutes := []string{} - assert.NotEqual(t, http.StatusNotFound, res.StatusCode) + for _, route := range server.router.Routes() { + registerdRoutes = append(registerdRoutes, route.Pattern) + } + for _, route := range testRoutes { + assert.Contains(t, registerdRoutes, route) } + } diff --git a/lib/errors.go b/lib/errors.go deleted file mode 100644 index a2ab47e..0000000 --- a/lib/errors.go +++ /dev/null @@ -1,35 +0,0 @@ -package lib - -import ( - "github.com/kareemmahlees/meta-x/models" - - "github.com/gofiber/fiber/v2" -) - -func BadRequestErr(c *fiber.Ctx, errMsg any) error { - return c.Status(fiber.StatusBadRequest).JSON(models.ErrResp{ - Code: fiber.StatusBadRequest, - Message: errMsg, - }) -} - -func UnprocessableEntityErr(c *fiber.Ctx, errMsg any) error { - return c.Status(fiber.StatusUnprocessableEntity).JSON(models.ErrResp{ - Code: fiber.StatusUnprocessableEntity, - Message: errMsg, - }) -} - -func ForbiddenErr(c *fiber.Ctx, errMsg any) error { - return c.Status(fiber.StatusForbidden).JSON(models.ErrResp{ - Code: fiber.StatusForbidden, - Message: errMsg, - }) -} - -func InternalServerErr(c *fiber.Ctx, errMsg any) error { - return c.Status(fiber.StatusInternalServerError).JSON(models.ErrResp{ - Code: fiber.StatusInternalServerError, - Message: errMsg, - }) -} diff --git a/lib/errors_test.go b/lib/errors_test.go deleted file mode 100644 index 219b389..0000000 --- a/lib/errors_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package lib - -import ( - "testing" - - "github.com/gofiber/fiber/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "github.com/valyala/fasthttp" -) - -type ErrorTestSuite struct { - suite.Suite - fiberCtx *fiber.Ctx -} - -func (suite *ErrorTestSuite) SetupSuite() { - app := fiber.New() - suite.fiberCtx = app.AcquireCtx(&fasthttp.RequestCtx{}) -} - -func (suite *ErrorTestSuite) TestBadRequestErr() { - t := suite.T() - - err := BadRequestErr(suite.fiberCtx, "anything") - assert.Nil(t, err) - - err = BadRequestErr(suite.fiberCtx, make(chan any)) - assert.NotNil(t, err) -} - -func (suite *ErrorTestSuite) TestUnprocessableEntityErr() { - t := suite.T() - - err := UnprocessableEntityErr(suite.fiberCtx, "anything") - assert.Nil(t, err) - - err = UnprocessableEntityErr(suite.fiberCtx, make(chan any)) - assert.NotNil(t, err) -} - -func (suite *ErrorTestSuite) TestForbiddenErr() { - t := suite.T() - - err := ForbiddenErr(suite.fiberCtx, "anything") - assert.Nil(t, err) - - err = ForbiddenErr(suite.fiberCtx, make(chan any)) - assert.NotNil(t, err) -} -func (suite *ErrorTestSuite) TestInternalServerErr() { - t := suite.T() - - err := InternalServerErr(suite.fiberCtx, "anything") - assert.Nil(t, err) - - err = InternalServerErr(suite.fiberCtx, make(chan any)) - assert.NotNil(t, err) -} - -func TestErrorsTestSuite(t *testing.T) { - suite.Run(t, new(ErrorTestSuite)) -} diff --git a/utils/testHelpers.go b/utils/testHelpers.go index edffe66..248ec06 100644 --- a/utils/testHelpers.go +++ b/utils/testHelpers.go @@ -4,13 +4,12 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" "time" - "github.com/gofiber/fiber/v2" + "github.com/go-chi/chi/v5" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" @@ -96,30 +95,9 @@ func SliceOfPointersToSliceOfValues[T any](s []*T) []T { return v } -type FiberRoute struct { - Method string - Path string -} - -// Struct for aiding the process of testing routes. -// Header is set by default to "Content-Type": "application/json" -type RequestTesting[T any] struct { - ReqMethod string - ReqUrl string // relative to the base url which is "http://localhost:5522" - ReqBody io.Reader - Res *http.Response - ResBody T -} - -// Runs a request and returns the decoded form [T] and the raw form [*http.Response] -func (rt *RequestTesting[T]) RunRequest(app *fiber.App) (T, *http.Response) { - req := httptest.NewRequest(rt.ReqMethod, fmt.Sprintf("http://localhost:5522%s", rt.ReqUrl), rt.ReqBody) - if rt.ReqBody != nil { - req.Header.Set("Content-Type", "application/json") - } - resp, _ := app.Test(req) - - resBody := DecodeBody[T](resp.Body) - rt.ResBody = resBody - return resBody, resp +func TestRequest(r *chi.Mux, method, url string, body io.Reader) *httptest.ResponseRecorder { + req, _ := http.NewRequest(method, url, body) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + return rr } diff --git a/utils/testHelpers_test.go b/utils/testHelpers_test.go index 88afba2..0a31f89 100644 --- a/utils/testHelpers_test.go +++ b/utils/testHelpers_test.go @@ -8,10 +8,11 @@ import ( "reflect" "strings" "testing" + "time" + "github.com/go-chi/chi/v5" "github.com/kareemmahlees/meta-x/lib" - "github.com/gofiber/fiber/v2" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" @@ -20,36 +21,63 @@ import ( ) func TestCreatePostgresContainer(t *testing.T) { - ctx := context.Background() - pgContainer, err := CreatePostgresContainer(ctx) - defer func() { - _ = pgContainer.Terminate(ctx) - }() + t.Run("should pass", func(t *testing.T) { + ctx := context.Background() + pgContainer, err := CreatePostgresContainer(ctx) + defer func() { + _ = pgContainer.Terminate(ctx) + }() - assert.Nil(t, err) + assert.Nil(t, err) - con, err := sqlx.Open(lib.PSQL, pgContainer.ConnectionString) - assert.Nil(t, err) + con, err := sqlx.Open(lib.PSQL, pgContainer.ConnectionString) + assert.Nil(t, err) - defer con.Close() + defer con.Close() - err = con.Ping() - assert.Nil(t, err) + err = con.Ping() + assert.Nil(t, err) + }) + + t.Run("should fail timeout exceeded", func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + + _, err := CreatePostgresContainer(ctx) + + assert.Error(t, err) + + }) } func TestCreateMySQLContainer(t *testing.T) { - ctx := context.Background() - mysqlContainer, err := CreateMySQLContainer(ctx) - defer func() { - _ = mysqlContainer.Terminate(ctx) - }() + t.Run("should pass", func(t *testing.T) { + ctx := context.Background() + mysqlContainer, err := CreateMySQLContainer(ctx) + defer func() { + _ = mysqlContainer.Terminate(ctx) + }() - assert.Nil(t, err) + assert.Nil(t, err) - con, err := sqlx.Open(lib.MYSQL, mysqlContainer.ConnectionString) - assert.Nil(t, err) + con, err := sqlx.Open(lib.MYSQL, mysqlContainer.ConnectionString) + assert.Nil(t, err) + + defer con.Close() - defer con.Close() + }) + + t.Run("should fail timetout exceeded", func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + + _, err := CreateMySQLContainer(ctx) + + assert.Error(t, err) + + }) } func TestEncodeBody(t *testing.T) { @@ -128,45 +156,15 @@ func TestSliceOfPointersToSliceOfValues(t *testing.T) { assert.IsType(t, reflect.SliceOf(reflect.TypeOf("")), reflect.TypeOf(soptsov)) } -func TestRunRequest(t *testing.T) { - app := fiber.New() - app.Get("/health", func(c *fiber.Ctx) error { - return c.Status(fiber.StatusOK).JSON(fiber.Map{"date": "fake_date"}) - }) - mockReq1 := RequestTesting[struct { - Date string `json:"date"` - }]{ - ReqMethod: http.MethodGet, - ReqUrl: "/health", - } - decodedRes, rawRes := mockReq1.RunRequest(app) - assert.Equal(t, http.StatusOK, rawRes.StatusCode) - assert.NotEmpty(t, decodedRes.Date) - - type mockPayload struct { - Name string `json:"name"` - } - - app.Post("/test", func(c *fiber.Ctx) error { - var payload mockPayload - if err := c.BodyParser(&payload); err != nil { - return c.Status(fiber.StatusUnprocessableEntity).JSON(fiber.Map{}) - } - return nil +func TestRequestTest(t *testing.T) { + r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello")) }) - mockBody, _ := EncodeBody(mockPayload{Name: "any"}) - mockReq2 := RequestTesting[any]{ - ReqMethod: http.MethodPost, - ReqUrl: "/test", - ReqBody: mockBody, - } - _, rawResponse := mockReq2.RunRequest(app) - assert.NotEqual(t, http.StatusUnprocessableEntity, rawResponse.StatusCode) - - mockReq3 := RequestTesting[any]{ - ReqMethod: http.MethodPost, - ReqUrl: "/test", - } - _, rawResponse = mockReq3.RunRequest(app) - assert.Equal(t, http.StatusUnprocessableEntity, rawResponse.StatusCode) + + rr := TestRequest(r, http.MethodGet, "/", http.NoBody) + + assert.Equal(t, rr.Code, http.StatusOK) + assert.Equal(t, rr.Body.String(), "Hello") }