From 7781a3ce155ce90b8128aedfd0e966e0ec3464da Mon Sep 17 00:00:00 2001 From: limpo1989 Date: Wed, 6 Dec 2023 19:50:36 +0800 Subject: [PATCH] Ignore non-existent parameters when binding parameters --- web/bind.go | 41 ++++++++++++++++++++++++++----- web/binding/binding.go | 31 ++++++++++++------------ web/binding/binding_test.go | 20 +++++++++------- web/context.go | 39 ++++++++++++++++++++---------- web/examples/greeting/main.go | 1 - web/options.go | 17 ++++++++++++- web/server.go | 45 ++++------------------------------- 7 files changed, 110 insertions(+), 84 deletions(-) diff --git a/web/bind.go b/web/bind.go index 0b252244..95526c48 100644 --- a/web/bind.go +++ b/web/bind.go @@ -17,6 +17,7 @@ package web import ( + "errors" "fmt" "net/http" "reflect" @@ -165,21 +166,21 @@ func validMappingFunc(fnType reflect.Type) error { } if fnType.NumIn() < 1 || fnType.NumIn() > 2 { - return fmt.Errorf("%s: invalid input parameter count", fnType.String()) + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [R, error]", fnType.String()) } if fnType.NumOut() > 2 { - return fmt.Errorf("%s: invalid output parameter count", fnType.String()) + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)]", fnType.String()) } if !utils.IsContextType(fnType.In(0)) { - return fmt.Errorf("%s: first input param type (%s) must be context", fnType.String(), fnType.In(0).String()) + return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)", fnType.String()) } if fnType.NumIn() > 1 { argType := fnType.In(1) if !(reflect.Struct == argType.Kind() || (reflect.Ptr == argType.Kind() && reflect.Struct == argType.Elem().Kind())) { - return fmt.Errorf("%s: second input param type (%s) must be struct/*struct", fnType.String(), argType.String()) + return fmt.Errorf("%s: input param type (%s) must be struct/*struct", fnType.String(), argType.String()) } } @@ -188,11 +189,11 @@ func validMappingFunc(fnType reflect.Type) error { case 1: // R | error case 2: // (R, error) if utils.IsErrorType(fnType.Out(0)) { - return fmt.Errorf("%s: first output param type not be error", fnType.String()) + return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String()) } if !utils.IsErrorType(fnType.Out(1)) { - return fmt.Errorf("%s: second output type (%s) must a error", fnType.String(), fnType.Out(1).String()) + return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String()) } } @@ -210,3 +211,31 @@ func requestWithCtx(r *http.Request, webCtx *Context) *http.Request { ctx := WithContext(r.Context(), webCtx) return r.WithContext(ctx) } + +func defaultJsonRender(ctx *Context, err error, result interface{}) { + + var code = 0 + var message = "" + if nil != err { + var e HttpError + if errors.As(err, &e) { + code = e.Code + message = e.Message + } else { + code = http.StatusInternalServerError + message = err.Error() + + if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) { + code = http.StatusBadRequest + } + } + } + + type response struct { + Code int `json:"code"` + Message string `json:"message,omitempty"` + Data interface{} `json:"data"` + } + + ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result}) +} diff --git a/web/binding/binding.go b/web/binding/binding.go index c299361a..87f48945 100644 --- a/web/binding/binding.go +++ b/web/binding/binding.go @@ -44,10 +44,10 @@ const ( type Request interface { ContentType() string - Header(key string) string - Cookie(name string) string - PathParam(name string) string - QueryParam(name string) string + Header(key string) (string, bool) + Cookie(name string) (string, bool) + PathParam(name string) (string, bool) + QueryParam(name string) (string, bool) FormParams() (url.Values, error) MultipartParams(maxMemory int64) (*multipart.Form, error) RequestBody() io.Reader @@ -70,7 +70,7 @@ var scopeTags = map[BindScope]string{ BindScopeCookie: "cookie", } -var scopeGetters = map[BindScope]func(r Request, name string) string{ +var scopeGetters = map[BindScope]func(r Request, name string) (string, bool){ BindScopeURI: Request.PathParam, BindScopeQuery: Request.QueryParam, BindScopeHeader: Request.Header, @@ -95,6 +95,11 @@ func RegisterBodyBinder(mime string, binder BodyBinder) { bodyBinders[mime] = binder } +// Bind checks the Method and Content-Type to select a binding engine automatically, +// Depending on the "Content-Type" header different bindings are used, for example: +// +// "application/json" --> JSON binding +// "application/xml" --> XML binding func Bind(i interface{}, r Request) error { if err := bindScope(i, r); err != nil { return fmt.Errorf("%w: %v", ErrBinding, err) @@ -138,8 +143,7 @@ func bindScope(i interface{}, r Request) error { fv := ev.Field(j) ft := et.Field(j) for scope := BindScopeURI; scope < BindScopeBody; scope++ { - err := bindScopeField(scope, fv, ft, r) - if err != nil { + if err := bindScopeField(scope, fv, ft, r); err != nil { return err } } @@ -149,14 +153,11 @@ func bindScope(i interface{}, r Request) error { func bindScopeField(scope BindScope, v reflect.Value, field reflect.StructField, r Request) error { if tag, loaded := scopeTags[scope]; loaded { - if name, ok := field.Tag.Lookup(tag); ok { - if name == "-" { - return nil // ignore bind - } - val := scopeGetters[scope](r, name) - err := bindData(v, val) - if err != nil { - return err + if name, ok := field.Tag.Lookup(tag); ok && name != "-" { + if val, exists := scopeGetters[scope](r, name); exists { + if err := bindData(v, val); err != nil { + return err + } } } } diff --git a/web/binding/binding_test.go b/web/binding/binding_test.go index df07d23f..0bff8fd6 100644 --- a/web/binding/binding_test.go +++ b/web/binding/binding_test.go @@ -44,20 +44,24 @@ func (r *MockRequest) ContentType() string { return r.contentType } -func (r *MockRequest) Header(key string) string { - return r.headers[key] +func (r *MockRequest) Header(key string) (string, bool) { + value, ok := r.headers[key] + return value, ok } -func (r *MockRequest) Cookie(name string) string { - return r.cookies[name] +func (r *MockRequest) Cookie(name string) (string, bool) { + value, ok := r.cookies[name] + return value, ok } -func (r *MockRequest) QueryParam(name string) string { - return r.queryParams[name] +func (r *MockRequest) QueryParam(name string) (string, bool) { + value, ok := r.queryParams[name] + return value, ok } -func (r *MockRequest) PathParam(name string) string { - return r.pathParams[name] +func (r *MockRequest) PathParam(name string) (string, bool) { + value, ok := r.pathParams[name] + return value, ok } func (r *MockRequest) FormParams() (url.Values, error) { diff --git a/web/context.go b/web/context.go index b6f39ef4..2013416d 100644 --- a/web/context.go +++ b/web/context.go @@ -23,6 +23,7 @@ import ( "mime/multipart" "net" "net/http" + "net/textproto" "net/url" "strings" "unicode" @@ -71,38 +72,43 @@ func (c *Context) ContentType() string { } // Header returns the named header in the request. -func (c *Context) Header(key string) string { - return c.Request.Header.Get(key) +func (c *Context) Header(key string) (string, bool) { + if values, ok := c.Request.Header[textproto.CanonicalMIMEHeaderKey(key)]; ok && len(values) > 0 { + return values[0], true + } + return "", false } // Cookie returns the named cookie provided in the request. -func (c *Context) Cookie(name string) string { +func (c *Context) Cookie(name string) (string, bool) { cookie, err := c.Request.Cookie(name) if err != nil { - return "" + return "", false + } + if val, err := url.QueryUnescape(cookie.Value); nil == err { + return val, true } - val, _ := url.QueryUnescape(cookie.Value) - return val + return cookie.Value, true } // PathParam returns the named variables in the request. -func (c *Context) PathParam(name string) string { +func (c *Context) PathParam(name string) (string, bool) { if params := mux.Vars(c.Request); nil != params { if value, ok := params[name]; ok { - return value + return value, true } } - return "" + return "", false } // QueryParam returns the named query in the request. -func (c *Context) QueryParam(name string) string { +func (c *Context) QueryParam(name string) (string, bool) { if values := c.Request.URL.Query(); nil != values { if value, ok := values[name]; ok && len(value) > 0 { - return value[0] + return value[0], true } } - return "" + return "", false } // FormParams returns the form in the request. @@ -183,6 +189,15 @@ func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, }) } +// Bind checks the Method and Content-Type to select a binding engine automatically, +// Depending on the "Content-Type" header different bindings are used, for example: +// +// "application/json" --> JSON binding +// "application/xml" --> XML binding +func (c *Context) Bind(r interface{}) error { + return binding.Bind(r, c) +} + // Render writes the response headers and calls render.Render to render data. func (c *Context) Render(code int, render render.Renderer) error { if code > 0 { diff --git a/web/examples/greeting/main.go b/web/examples/greeting/main.go index c7f85722..c1d0818a 100644 --- a/web/examples/greeting/main.go +++ b/web/examples/greeting/main.go @@ -43,7 +43,6 @@ func (g *Greeting) OnInit(ctx context.Context) error { g.Server.Use(func(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - start := time.Now() handler.ServeHTTP(writer, request) g.Logger.Info("http handle cost", diff --git a/web/options.go b/web/options.go index 9a1fc2cd..88bb54a1 100644 --- a/web/options.go +++ b/web/options.go @@ -16,7 +16,10 @@ package web -import "time" +import ( + "crypto/tls" + "time" +) type Options struct { // Addr optionally specifies the TCP address for the server to listen on, @@ -74,3 +77,15 @@ type Options struct { // If zero, DefaultMaxHeaderBytes is used. MaxHeaderBytes int `value:"${max-header-bytes:=0}"` } + +func (options Options) IsTls() bool { + return len(options.CertFile) > 0 && len(options.KeyFile) > 0 +} + +func (options Options) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile) + if err != nil { + return nil, err + } + return &cert, nil +} diff --git a/web/server.go b/web/server.go index 3b61543f..bba67ac1 100644 --- a/web/server.go +++ b/web/server.go @@ -19,10 +19,7 @@ package web import ( "context" "crypto/tls" - "errors" "net/http" - - "github.com/go-spring-projects/go-spring/web/binding" ) // A Server defines parameters for running an HTTP server. @@ -42,53 +39,18 @@ func NewServer(router *Router, options Options) *Server { } var tlsConfig *tls.Config - if len(options.CertFile) > 0 && len(options.KeyFile) > 0 { + if options.IsTls() { tlsConfig = &tls.Config{ - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile) - if err != nil { - return nil, err - } - return &cert, nil - }, - } - } - - var jsonRenderer = func(ctx *Context, err error, result interface{}) { - - var code = 0 - var message = "" - if nil != err { - var e HttpError - if errors.As(err, &e) { - code = e.Code - message = e.Message - } else { - code = http.StatusInternalServerError - message = err.Error() - - if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) { - code = http.StatusBadRequest - } - } + GetCertificate: options.GetCertificate, } - - type response struct { - Code int `json:"code"` - Message string `json:"message,omitempty"` - Data interface{} `json:"data"` - } - - ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result}) } return &Server{ options: options, router: router, - renderer: RendererFunc(jsonRenderer), + renderer: RendererFunc(defaultJsonRender), httpSvr: &http.Server{ Addr: addr, - Handler: router, TLSConfig: tlsConfig, ReadTimeout: options.ReadTimeout, ReadHeaderTimeout: options.ReadHeaderTimeout, @@ -108,6 +70,7 @@ func (s *Server) Addr() string { // calls Serve to handle requests on incoming connections. // Accepted connections are configured to enable TCP keep-alives. func (s *Server) Run() error { + s.httpSvr.Handler = s if nil != s.httpSvr.TLSConfig { return s.httpSvr.ListenAndServeTLS(s.options.CertFile, s.options.KeyFile) }