diff --git a/.gitignore b/.gitignore index daf913b..bcd81c4 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ _testmain.go *.exe *.test *.prof + +swagger.json diff --git a/example/widgets/routes.go b/example/widgets/routes.go index d3633c6..0c3ad46 100644 --- a/example/widgets/routes.go +++ b/example/widgets/routes.go @@ -14,7 +14,7 @@ var Routes = []crud.Spec{{ Tags: tags, Validate: crud.Validate{ Query: map[string]crud.Field{ - "limit": crud.Number(), + "limit": crud.Number().Required().Min(0).Max(25).Description("Records to return"), }, }, }, { diff --git a/prehandler.go b/prehandler.go new file mode 100644 index 0000000..dd62198 --- /dev/null +++ b/prehandler.go @@ -0,0 +1,99 @@ +package crud + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io/ioutil" + "net/http" + "strconv" +) + +// this is where the validation happens! +func preHandler(spec Spec) gin.HandlerFunc { + return func(c *gin.Context) { + val := spec.Validate + if val.Query != nil { + values := c.Request.URL.Query() + for field, schema := range val.Query { + // query values are always strings, so we must try to convert + queryValue := values.Get(field) + + // don't try to convert if the field is empty + if queryValue == "" { + if schema.required != nil && *schema.required { + c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrRequired)) + } + return + } + var convertedValue interface{} + switch schema.kind { + case KindBoolean: + if queryValue == "true" { + convertedValue = true + } else if queryValue == "false" { + convertedValue = false + } else { + c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) + return + } + case KindString: + convertedValue = queryValue + case KindNumber: + var err error + convertedValue, err = strconv.ParseFloat(queryValue, 64) + if err != nil { + c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) + return + } + case KindInteger: + var err error + convertedValue, err = strconv.Atoi(queryValue) + if err != nil { + c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) + return + } + case KindArray: + // TODO I'm not sure how this works yet + c.AbortWithStatusJSON(http.StatusNotImplemented, "TODO") + return + default: + c.AbortWithStatusJSON(400, fmt.Sprintf("Validation not possible due to kind: %v", schema.kind)) + } + if err := schema.Validate(convertedValue); err != nil { + c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, err.Error())) + return + } + } + } + + if val.Body != nil { + // TODO this could be an array, basic type, or "null" + var body map[string]interface{} + if err := c.BindJSON(&body); err != nil { + c.AbortWithStatusJSON(400, err.Error()) + return + } + for field, schema := range val.Body { + if err := schema.Validate(body[field]); err != nil { + c.AbortWithStatusJSON(400, fmt.Sprintf("Body validation failed for field %v: %v", field, err.Error())) + return + } + } + // TODO perhaps the user passes a struct to bind to instead? + data, _ := json.Marshal(body) + c.Request.Body = ioutil.NopCloser(bytes.NewReader(data)) + } + + if val.Path != nil { + for field, schema := range val.Path { + path := c.Param(field) + if schema.required != nil && *schema.required && path == "" { + c.AbortWithStatusJSON(400, fmt.Sprintf("Missing path param")) + return + } + } + } + } +} diff --git a/prehandler_test.go b/prehandler_test.go new file mode 100644 index 0000000..bb52b8a --- /dev/null +++ b/prehandler_test.go @@ -0,0 +1,154 @@ +package crud + +import ( + "github.com/gin-gonic/gin" + "net/http/httptest" + "testing" +) + +func init() { + gin.SetMode(gin.ReleaseMode) +} + +func query(query string) (*httptest.ResponseRecorder, *gin.Context) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "http://example.com"+query, nil) + return w, c +} + +func TestQueryValidation(t *testing.T) { + tests := []struct { + Schema map[string]Field + Input string + Expected int + }{ + { + Schema: map[string]Field{ + "testquery": String(), + }, + Input: "", + Expected: 200, + }, { + Schema: map[string]Field{ + "testquery": String().Required(), + }, + Input: "", + Expected: 400, + }, { + Schema: map[string]Field{ + "testquery": String().Required(), + }, + Input: "?testquery=", + Expected: 400, + }, { + Schema: map[string]Field{ + "testquery": String().Required(), + }, + Input: "?testquery=ok", + Expected: 200, + }, { + Schema: map[string]Field{ + "testquery": Number(), + }, + Input: "", + Expected: 200, + }, { + Schema: map[string]Field{ + "testquery": Number().Required(), + }, + Input: "", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Number().Required(), + }, + Input: "?testquery=1", + Expected: 200, + }, + { + Schema: map[string]Field{ + "testquery": Number().Required(), + }, + Input: "?testquery=1.1", + Expected: 200, + }, + { + Schema: map[string]Field{ + "testquery": Number(), + }, + Input: "?testquery=a", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Boolean(), + }, + Input: "?testquery=true", + Expected: 200, + }, + { + Schema: map[string]Field{ + "testquery": Boolean(), + }, + Input: "?testquery=false", + Expected: 200, + }, + { + Schema: map[string]Field{ + "testquery": Boolean(), + }, + Input: "?testquery=1", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Integer(), + }, + Input: "?testquery=1", + Expected: 200, + }, + { + Schema: map[string]Field{ + "testquery": Integer().Max(1), + }, + Input: "?testquery=2", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Integer().Min(5), + }, + Input: "?testquery=4", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Integer(), + }, + Input: "?testquery=1.1", + Expected: 400, + }, + { + Schema: map[string]Field{ + "testquery": Integer(), + }, + Input: "?testquery=a", + Expected: 400, + }, + } + + for _, test := range tests { + handler := preHandler(Spec{ + Validate: Validate{Query: test.Schema}, + }) + + w, c := query(test.Input) + handler(c) + + if w.Result().StatusCode != test.Expected { + t.Errorf("expected '%v' got '%v'. input: '%v'. schema: '%v'", test.Expected, w.Code, test.Input, test.Schema) + } + } +} diff --git a/router.go b/router.go index f98c640..3ef44e5 100644 --- a/router.go +++ b/router.go @@ -1,136 +1,43 @@ package crud import ( - "bytes" _ "embed" - "encoding/json" "fmt" "github.com/gin-gonic/gin" - "io/ioutil" - "net/http" "regexp" - "strconv" "strings" ) +// Router is the main object that is used to generate swagger and holds the underlying router. type Router struct { - Swagger string `json:"swagger"` - Info Info `json:"info"` - Paths map[string]*Path `json:"paths"` - Definitions map[string]JsonSchema `json:"definitions"` + // Swagger is exposed so the user can edit additional optional fields. + Swagger Swagger - Specs []Spec `json:"-"` - Mux *gin.Engine `json:"-"` -} - -type Info struct { - Title string `json:"title"` - Version string `json:"version"` -} - -type JsonSchema struct { - Type string `json:"type,omitempty"` - Properties map[string]JsonSchema `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` - Example interface{} `json:"example,omitempty"` - Description string `json:"description,omitempty"` - Minimum float64 `json:"minimum,omitempty"` - Maximum float64 `json:"maximum,omitempty"` - Enum []interface{} `json:"enum,omitempty"` -} - -type Path struct { - Get *Operation `json:"get,omitempty"` - Post *Operation `json:"post,omitempty"` - Put *Operation `json:"put,omitempty"` - Delete *Operation `json:"delete,omitempty"` - Patch *Operation `json:"patch,omitempty"` -} - -type Operation struct { - Tags []string `json:"tags,omitempty"` - Parameters []Parameter `json:"parameters,omitempty"` - Responses map[string]Response `json:"responses"` - Description string `json:"description"` - Summary string `json:"summary"` -} - -type Parameter struct { - In string `json:"in"` - Name string `json:"name"` - - Type string `json:"type,omitempty"` - Schema *Ref `json:"schema,omitempty"` + // Mux is the underlying router being used. The user can add middlewares and use other features. + Mux *gin.Engine - Required *bool `json:"required,omitempty"` - Description string `json:"description,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - Enum []interface{} `json:"enum,omitempty"` -} - -type Ref struct { - Ref string `json:"$ref,omitempty"` -} - -type Response struct { - Schema JsonSchema `json:"schema"` - Description string `json:"description"` -} - -var DefaultResponse = map[string]Response{ - "default": { - Schema: JsonSchema{Type: "string"}, - Description: "Successful", - }, + // used for automatically incrementing the model name, e.g. Model 1, Model 2. + modelCounter int } +// NewRouter initializes a router. func NewRouter(title, version string) *Router { return &Router{ - Swagger: "2.0", - Info: Info{ - Title: title, - Version: version, + Swagger: Swagger{ + Swagger: "2.0", + Info: Info{Title: title, Version: version}, + Paths: map[string]*Path{}, + Definitions: map[string]JsonSchema{}, }, - Mux: gin.Default(), + Mux: gin.Default(), + modelCounter: 1, } } +// Add routes to the swagger spec and installs a handler with built-in validation. func (r *Router) Add(specs ...Spec) { - r.Specs = append(r.Specs, specs...) -} - -type Spec struct { - Method string - Path string - PreHandlers []gin.HandlerFunc - Handler gin.HandlerFunc - Description string - Tags []string - Summary string - - Validate Validate -} - -type Validate struct { - Query map[string]Field - Body map[string]Field - Path map[string]Field - FormData map[string]Field - Header map[string]Field -} - -func (r *Router) Use(middlewares ...gin.HandlerFunc) { - r.Mux.Use(middlewares...) -} - -func (r *Router) Serve(addr string) error { - modelCounter := 1 - r.Definitions = map[string]JsonSchema{} - - r.Paths = map[string]*Path{} - for i := range r.Specs { - spec := r.Specs[i] + for i := range specs { + spec := specs[i] handlers := []gin.HandlerFunc{preHandler(spec)} handlers = append(handlers, spec.PreHandlers...) @@ -138,30 +45,31 @@ func (r *Router) Serve(addr string) error { r.Mux.Handle(spec.Method, swaggerToGinPattern(spec.Path), handlers...) - if _, ok := r.Paths[spec.Path]; !ok { - r.Paths[spec.Path] = &Path{} + if _, ok := r.Swagger.Paths[spec.Path]; !ok { + r.Swagger.Paths[spec.Path] = &Path{} } - path := r.Paths[spec.Path] + path := r.Swagger.Paths[spec.Path] var operation *Operation switch strings.ToLower(spec.Method) { case "get": - path.Get = &Operation{Responses: DefaultResponse} + path.Get = &Operation{} operation = path.Get case "post": - path.Post = &Operation{Responses: DefaultResponse} + path.Post = &Operation{} operation = path.Post case "put": - path.Put = &Operation{Responses: DefaultResponse} + path.Put = &Operation{} operation = path.Put case "patch": - path.Patch = &Operation{Responses: DefaultResponse} + path.Patch = &Operation{} operation = path.Patch case "delete": - path.Delete = &Operation{Responses: DefaultResponse} + path.Delete = &Operation{} operation = path.Delete default: panic("Unhandled method " + spec.Method) } + operation.Responses = DefaultResponse operation.Tags = spec.Tags operation.Description = spec.Description operation.Summary = spec.Summary @@ -197,20 +105,46 @@ func (r *Router) Serve(addr string) error { } } if spec.Validate.Body != nil { - modelName := fmt.Sprintf("Model %v", modelCounter) + modelName := fmt.Sprintf("Model %v", r.modelCounter) parameter := Parameter{ In: "body", Name: "body", Schema: &Ref{fmt.Sprint("#/definitions/", modelName)}, } - r.Definitions[modelName] = ToJsonSchema(spec.Validate.Body) - modelCounter++ + r.Swagger.Definitions[modelName] = ToJsonSchema(spec.Validate.Body) + r.modelCounter++ operation.Parameters = append(operation.Parameters, parameter) } } +} + +// Spec is used to generate swagger paths and automatic handler validation. +type Spec struct { + Method string + Path string + PreHandlers []gin.HandlerFunc + Handler gin.HandlerFunc + Description string + Tags []string + Summary string + Validate Validate +} + +// Validate are optional fields that will be used during validation. Leave unneeded +// properties nil and they will be ignored. +type Validate struct { + Query map[string]Field + Body map[string]Field + Path map[string]Field + FormData map[string]Field + Header map[string]Field +} + +// Serve installs the swagger and the swagger-ui and runs the server. +func (r *Router) Serve(addr string) error { r.Mux.GET("/swagger.json", func(c *gin.Context) { - c.JSON(200, r) + c.JSON(200, r.Swagger) }) r.Mux.GET("/", func(c *gin.Context) { @@ -225,92 +159,7 @@ func (r *Router) Serve(addr string) error { return err } -func preHandler(spec Spec) gin.HandlerFunc { - return func(c *gin.Context) { - val := spec.Validate - if val.Query != nil { - values := c.Request.URL.Query() - for field, schema := range val.Query { - // query values are always strings, so we must try to convert - queryValue := values.Get(field) - - // don't try to convert if the field is empty - if queryValue == "" { - if schema.required != nil && *schema.required { - c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrRequired)) - } - return - } - var convertedValue interface{} - switch schema.kind { - case KindBoolean: - if queryValue == "true" { - convertedValue = true - } else if queryValue == "false" { - convertedValue = false - } else { - c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) - return - } - case KindString: - convertedValue = queryValue - case KindNumber: - var err error - convertedValue, err = strconv.ParseFloat(queryValue, 64) - if err != nil { - c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) - return - } - case KindInteger: - var err error - convertedValue, err = strconv.Atoi(queryValue) - if err != nil { - c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, ErrWrongType)) - return - } - case KindArray: - // TODO I'm not sure how this works yet - c.AbortWithStatusJSON(http.StatusNotImplemented, "TODO") - return - default: - c.AbortWithStatusJSON(400, fmt.Sprintf("Validation not possible due to kind: %v", schema.kind)) - } - if err := schema.Validate(convertedValue); err != nil { - c.AbortWithStatusJSON(400, fmt.Sprintf("Query validation failed for field %v: %v", field, err.Error())) - return - } - } - } - - if val.Body != nil { - var body map[string]interface{} - if err := c.BindJSON(&body); err != nil { - c.AbortWithStatusJSON(400, err.Error()) - return - } - for field, schema := range val.Body { - if err := schema.Validate(body[field]); err != nil { - c.AbortWithStatusJSON(400, fmt.Sprintf("Body validation failed for field %v: %v", field, err.Error())) - return - } - } - // TODO perhaps the user passes a struct to bind to instead? - data, _ := json.Marshal(body) - c.Request.Body = ioutil.NopCloser(bytes.NewReader(data)) - } - - if val.Path != nil { - for field, schema := range val.Path { - path := c.Param(field) - if schema.required != nil && *schema.required && path == "" { - c.AbortWithStatusJSON(400, fmt.Sprintf("Missing path param")) - return - } - } - } - } -} - +// we need to convert swagger endpoints /widget/{id} to gin endpoints /widget/:id var swaggerPathPattern = regexp.MustCompile("\\{([^}]+)\\}") func swaggerToGinPattern(ginUrl string) string { diff --git a/schema.go b/schema.go index 20fdae3..9031b10 100644 --- a/schema.go +++ b/schema.go @@ -31,7 +31,7 @@ func (f *Field) Validate(value interface{}) error { switch v := value.(type) { case int: - if f.kind != "number" { + if f.kind != "integer" { return ErrWrongType } if f.max != nil && float64(v) > *f.max { diff --git a/swagger.go b/swagger.go new file mode 100644 index 0000000..488a644 --- /dev/null +++ b/swagger.go @@ -0,0 +1,71 @@ +package crud + +type Swagger struct { + Swagger string `json:"swagger"` + Info Info `json:"info"` + + Paths map[string]*Path `json:"paths"` + Definitions map[string]JsonSchema `json:"definitions"` +} + +type Info struct { + Title string `json:"title"` + Version string `json:"version"` +} + +type JsonSchema struct { + Type string `json:"type,omitempty"` + Properties map[string]JsonSchema `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` + Example interface{} `json:"example,omitempty"` + Description string `json:"description,omitempty"` + Minimum float64 `json:"minimum,omitempty"` + Maximum float64 `json:"maximum,omitempty"` + Enum []interface{} `json:"enum,omitempty"` +} + +type Path struct { + Get *Operation `json:"get,omitempty"` + Post *Operation `json:"post,omitempty"` + Put *Operation `json:"put,omitempty"` + Delete *Operation `json:"delete,omitempty"` + Patch *Operation `json:"patch,omitempty"` +} + +type Operation struct { + Tags []string `json:"tags,omitempty"` + Parameters []Parameter `json:"parameters,omitempty"` + Responses map[string]Response `json:"responses"` + Description string `json:"description"` + Summary string `json:"summary"` +} + +type Parameter struct { + In string `json:"in"` + Name string `json:"name"` + + Type string `json:"type,omitempty"` + Schema *Ref `json:"schema,omitempty"` + + Required *bool `json:"required,omitempty"` + Description string `json:"description,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + Enum []interface{} `json:"enum,omitempty"` +} + +type Ref struct { + Ref string `json:"$ref,omitempty"` +} + +type Response struct { + Schema JsonSchema `json:"schema"` + Description string `json:"description"` +} + +var DefaultResponse = map[string]Response{ + "default": { + Schema: JsonSchema{Type: "string"}, + Description: "Successful", + }, +}