Skip to content

Commit

Permalink
Merge pull request #13 from go-spring-projects/fix-scope-param
Browse files Browse the repository at this point in the history
Ignore non-existent parameters when binding parameters
  • Loading branch information
limpo1989 authored Dec 6, 2023
2 parents ab3b8c6 + 7781a3c commit 8eebf1b
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 84 deletions.
41 changes: 35 additions & 6 deletions web/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package web

import (
"errors"
"fmt"
"net/http"
"reflect"
Expand Down Expand Up @@ -165,21 +166,21 @@ func validMappingFunc(fnType reflect.Type) error {
}

if fnType.NumIn() < 1 || fnType.NumIn() > 2 {
return fmt.Errorf("%s: invalid input parameter count", fnType.String())
return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [R, error]", fnType.String())
}

if fnType.NumOut() > 2 {
return fmt.Errorf("%s: invalid output parameter count", fnType.String())
return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)]", fnType.String())
}

if !utils.IsContextType(fnType.In(0)) {
return fmt.Errorf("%s: first input param type (%s) must be context", fnType.String(), fnType.In(0).String())
return fmt.Errorf("%s: expect func(ctx context.Context, [T]) [(R, error)", fnType.String())
}

if fnType.NumIn() > 1 {
argType := fnType.In(1)
if !(reflect.Struct == argType.Kind() || (reflect.Ptr == argType.Kind() && reflect.Struct == argType.Elem().Kind())) {
return fmt.Errorf("%s: second input param type (%s) must be struct/*struct", fnType.String(), argType.String())
return fmt.Errorf("%s: input param type (%s) must be struct/*struct", fnType.String(), argType.String())
}
}

Expand All @@ -188,11 +189,11 @@ func validMappingFunc(fnType reflect.Type) error {
case 1: // R | error
case 2: // (R, error)
if utils.IsErrorType(fnType.Out(0)) {
return fmt.Errorf("%s: first output param type not be error", fnType.String())
return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String())
}

if !utils.IsErrorType(fnType.Out(1)) {
return fmt.Errorf("%s: second output type (%s) must a error", fnType.String(), fnType.Out(1).String())
return fmt.Errorf("%s: expect func(...) (R, error)", fnType.String())
}
}

Expand All @@ -210,3 +211,31 @@ func requestWithCtx(r *http.Request, webCtx *Context) *http.Request {
ctx := WithContext(r.Context(), webCtx)
return r.WithContext(ctx)
}

func defaultJsonRender(ctx *Context, err error, result interface{}) {

var code = 0
var message = ""
if nil != err {
var e HttpError
if errors.As(err, &e) {
code = e.Code
message = e.Message
} else {
code = http.StatusInternalServerError
message = err.Error()

if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) {
code = http.StatusBadRequest
}
}
}

type response struct {
Code int `json:"code"`
Message string `json:"message,omitempty"`
Data interface{} `json:"data"`
}

ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result})
}
31 changes: 16 additions & 15 deletions web/binding/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ const (

type Request interface {
ContentType() string
Header(key string) string
Cookie(name string) string
PathParam(name string) string
QueryParam(name string) string
Header(key string) (string, bool)
Cookie(name string) (string, bool)
PathParam(name string) (string, bool)
QueryParam(name string) (string, bool)
FormParams() (url.Values, error)
MultipartParams(maxMemory int64) (*multipart.Form, error)
RequestBody() io.Reader
Expand All @@ -70,7 +70,7 @@ var scopeTags = map[BindScope]string{
BindScopeCookie: "cookie",
}

var scopeGetters = map[BindScope]func(r Request, name string) string{
var scopeGetters = map[BindScope]func(r Request, name string) (string, bool){
BindScopeURI: Request.PathParam,
BindScopeQuery: Request.QueryParam,
BindScopeHeader: Request.Header,
Expand All @@ -95,6 +95,11 @@ func RegisterBodyBinder(mime string, binder BodyBinder) {
bodyBinders[mime] = binder
}

// Bind checks the Method and Content-Type to select a binding engine automatically,
// Depending on the "Content-Type" header different bindings are used, for example:
//
// "application/json" --> JSON binding
// "application/xml" --> XML binding
func Bind(i interface{}, r Request) error {
if err := bindScope(i, r); err != nil {
return fmt.Errorf("%w: %v", ErrBinding, err)
Expand Down Expand Up @@ -138,8 +143,7 @@ func bindScope(i interface{}, r Request) error {
fv := ev.Field(j)
ft := et.Field(j)
for scope := BindScopeURI; scope < BindScopeBody; scope++ {
err := bindScopeField(scope, fv, ft, r)
if err != nil {
if err := bindScopeField(scope, fv, ft, r); err != nil {
return err
}
}
Expand All @@ -149,14 +153,11 @@ func bindScope(i interface{}, r Request) error {

func bindScopeField(scope BindScope, v reflect.Value, field reflect.StructField, r Request) error {
if tag, loaded := scopeTags[scope]; loaded {
if name, ok := field.Tag.Lookup(tag); ok {
if name == "-" {
return nil // ignore bind
}
val := scopeGetters[scope](r, name)
err := bindData(v, val)
if err != nil {
return err
if name, ok := field.Tag.Lookup(tag); ok && name != "-" {
if val, exists := scopeGetters[scope](r, name); exists {
if err := bindData(v, val); err != nil {
return err
}
}
}
}
Expand Down
20 changes: 12 additions & 8 deletions web/binding/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,24 @@ func (r *MockRequest) ContentType() string {
return r.contentType
}

func (r *MockRequest) Header(key string) string {
return r.headers[key]
func (r *MockRequest) Header(key string) (string, bool) {
value, ok := r.headers[key]
return value, ok
}

func (r *MockRequest) Cookie(name string) string {
return r.cookies[name]
func (r *MockRequest) Cookie(name string) (string, bool) {
value, ok := r.cookies[name]
return value, ok
}

func (r *MockRequest) QueryParam(name string) string {
return r.queryParams[name]
func (r *MockRequest) QueryParam(name string) (string, bool) {
value, ok := r.queryParams[name]
return value, ok
}

func (r *MockRequest) PathParam(name string) string {
return r.pathParams[name]
func (r *MockRequest) PathParam(name string) (string, bool) {
value, ok := r.pathParams[name]
return value, ok
}

func (r *MockRequest) FormParams() (url.Values, error) {
Expand Down
39 changes: 27 additions & 12 deletions web/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"mime/multipart"
"net"
"net/http"
"net/textproto"
"net/url"
"strings"
"unicode"
Expand Down Expand Up @@ -71,38 +72,43 @@ func (c *Context) ContentType() string {
}

// Header returns the named header in the request.
func (c *Context) Header(key string) string {
return c.Request.Header.Get(key)
func (c *Context) Header(key string) (string, bool) {
if values, ok := c.Request.Header[textproto.CanonicalMIMEHeaderKey(key)]; ok && len(values) > 0 {
return values[0], true
}
return "", false
}

// Cookie returns the named cookie provided in the request.
func (c *Context) Cookie(name string) string {
func (c *Context) Cookie(name string) (string, bool) {
cookie, err := c.Request.Cookie(name)
if err != nil {
return ""
return "", false
}
if val, err := url.QueryUnescape(cookie.Value); nil == err {
return val, true
}
val, _ := url.QueryUnescape(cookie.Value)
return val
return cookie.Value, true
}

// PathParam returns the named variables in the request.
func (c *Context) PathParam(name string) string {
func (c *Context) PathParam(name string) (string, bool) {
if params := mux.Vars(c.Request); nil != params {
if value, ok := params[name]; ok {
return value
return value, true
}
}
return ""
return "", false
}

// QueryParam returns the named query in the request.
func (c *Context) QueryParam(name string) string {
func (c *Context) QueryParam(name string) (string, bool) {
if values := c.Request.URL.Query(); nil != values {
if value, ok := values[name]; ok && len(value) > 0 {
return value[0]
return value[0], true
}
}
return ""
return "", false
}

// FormParams returns the form in the request.
Expand Down Expand Up @@ -183,6 +189,15 @@ func (c *Context) SetCookie(name, value string, maxAge int, path, domain string,
})
}

// Bind checks the Method and Content-Type to select a binding engine automatically,
// Depending on the "Content-Type" header different bindings are used, for example:
//
// "application/json" --> JSON binding
// "application/xml" --> XML binding
func (c *Context) Bind(r interface{}) error {
return binding.Bind(r, c)
}

// Render writes the response headers and calls render.Render to render data.
func (c *Context) Render(code int, render render.Renderer) error {
if code > 0 {
Expand Down
1 change: 0 additions & 1 deletion web/examples/greeting/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func (g *Greeting) OnInit(ctx context.Context) error {
g.Server.Use(func(handler http.Handler) http.Handler {

return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

start := time.Now()
handler.ServeHTTP(writer, request)
g.Logger.Info("http handle cost",
Expand Down
17 changes: 16 additions & 1 deletion web/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package web

import "time"
import (
"crypto/tls"
"time"
)

type Options struct {
// Addr optionally specifies the TCP address for the server to listen on,
Expand Down Expand Up @@ -74,3 +77,15 @@ type Options struct {
// If zero, DefaultMaxHeaderBytes is used.
MaxHeaderBytes int `value:"${max-header-bytes:=0}"`
}

func (options Options) IsTls() bool {
return len(options.CertFile) > 0 && len(options.KeyFile) > 0
}

func (options Options) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile)
if err != nil {
return nil, err
}
return &cert, nil
}
45 changes: 4 additions & 41 deletions web/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ package web
import (
"context"
"crypto/tls"
"errors"
"net/http"

"github.com/go-spring-projects/go-spring/web/binding"
)

// A Server defines parameters for running an HTTP server.
Expand All @@ -42,53 +39,18 @@ func NewServer(router *Router, options Options) *Server {
}

var tlsConfig *tls.Config
if len(options.CertFile) > 0 && len(options.KeyFile) > 0 {
if options.IsTls() {
tlsConfig = &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(options.CertFile, options.KeyFile)
if err != nil {
return nil, err
}
return &cert, nil
},
}
}

var jsonRenderer = func(ctx *Context, err error, result interface{}) {

var code = 0
var message = ""
if nil != err {
var e HttpError
if errors.As(err, &e) {
code = e.Code
message = e.Message
} else {
code = http.StatusInternalServerError
message = err.Error()

if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) {
code = http.StatusBadRequest
}
}
GetCertificate: options.GetCertificate,
}

type response struct {
Code int `json:"code"`
Message string `json:"message,omitempty"`
Data interface{} `json:"data"`
}

ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result})
}

return &Server{
options: options,
router: router,
renderer: RendererFunc(jsonRenderer),
renderer: RendererFunc(defaultJsonRender),
httpSvr: &http.Server{
Addr: addr,
Handler: router,
TLSConfig: tlsConfig,
ReadTimeout: options.ReadTimeout,
ReadHeaderTimeout: options.ReadHeaderTimeout,
Expand All @@ -108,6 +70,7 @@ func (s *Server) Addr() string {
// calls Serve to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) Run() error {
s.httpSvr.Handler = s
if nil != s.httpSvr.TLSConfig {
return s.httpSvr.ListenAndServeTLS(s.options.CertFile, s.options.KeyFile)
}
Expand Down

0 comments on commit 8eebf1b

Please sign in to comment.