diff --git a/web/bind.go b/web/bind.go index 8fc574e2..0b252244 100644 --- a/web/bind.go +++ b/web/bind.go @@ -41,8 +41,12 @@ func (fn RendererFunc) Render(ctx *Context, err error, result interface{}) { // // func(ctx context.Context) R // +// func(ctx context.Context) error +// // func(ctx context.Context, req T) R // +// func(ctx context.Context, req T) error +// // func(ctx context.Context, req T) (R, error) // // func(writer http.ResponseWriter, request *http.Request) @@ -53,11 +57,11 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { switch h := fn.(type) { case http.HandlerFunc: - return h + return warpHandlerCtx(h) case http.Handler: - return h.ServeHTTP + return warpHandlerCtx(h.ServeHTTP) case func(http.ResponseWriter, *http.Request): - return h + return warpHandlerCtx(h) default: // valid func if err := validMappingFunc(fnType); nil != err { @@ -65,6 +69,8 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { } } + firstOutIsErrorType := 1 == fnType.NumOut() && utils.IsErrorType(fnType.Out(0)) + return func(writer http.ResponseWriter, request *http.Request) { // param of context @@ -128,14 +134,15 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { // nothing return case 1: - // write response - result = returnValues[0].Interface() + if firstOutIsErrorType { + err, _ = returnValues[0].Interface().(error) + } else { + result = returnValues[0].Interface() + } case 2: // check error result = returnValues[0].Interface() - if e, ok := returnValues[1].Interface().(error); ok && nil != e { - err = e - } + err, _ = returnValues[1].Interface().(error) default: panic("unreachable here") } @@ -149,7 +156,9 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { func validMappingFunc(fnType reflect.Type) error { // func(ctx context.Context) // func(ctx context.Context) R + // func(ctx context.Context) error // func(ctx context.Context, req T) R + // func(ctx context.Context, req T) error // func(ctx context.Context, req T) (R, error) if !utils.IsFuncType(fnType) { return fmt.Errorf("%s: not a func", fnType.String()) @@ -174,13 +183,30 @@ func validMappingFunc(fnType reflect.Type) error { } } - if 0 < fnType.NumOut() && utils.IsErrorType(fnType.Out(0)) { - return fmt.Errorf("%s: first output param type not be error", fnType.String()) - } + switch fnType.NumOut() { + case 0: // nothing + case 1: // R | error + case 2: // (R, error) + if utils.IsErrorType(fnType.Out(0)) { + return fmt.Errorf("%s: first output param type not be error", fnType.String()) + } - if 1 < fnType.NumOut() && !utils.IsErrorType(fnType.Out(1)) { - return fmt.Errorf("%s: second output type (%s) must a error", fnType.String(), fnType.Out(1).String()) + if !utils.IsErrorType(fnType.Out(1)) { + return fmt.Errorf("%s: second output type (%s) must a error", fnType.String(), fnType.Out(1).String()) + } } return nil } + +func warpHandlerCtx(handler http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + webCtx := &Context{Writer: writer, Request: request} + handler.ServeHTTP(writer, requestWithCtx(request, webCtx)) + } +} + +func requestWithCtx(r *http.Request, webCtx *Context) *http.Request { + ctx := WithContext(r.Context(), webCtx) + return r.WithContext(ctx) +} diff --git a/web/context.go b/web/context.go index c1f46de7..b6f39ef4 100644 --- a/web/context.go +++ b/web/context.go @@ -186,6 +186,11 @@ func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, // Render writes the response headers and calls render.Render to render data. func (c *Context) Render(code int, render render.Renderer) error { if code > 0 { + if len(c.Writer.Header().Get("Content-Type")) <= 0 { + if contentType := render.ContentType(); len(contentType) > 0 { + c.Writer.Header().Set("Content-Type", contentType) + } + } c.Writer.WriteHeader(code) } return render.Render(c.Writer) @@ -203,7 +208,7 @@ func (c *Context) String(code int, format string, args ...interface{}) error { // Data writes some data into the body stream and updates the HTTP code. func (c *Context) Data(code int, contentType string, data []byte) error { - return c.Render(code, render.BinaryRenderer{ContentType: contentType, Data: data}) + return c.Render(code, render.BinaryRenderer{DataType: contentType, Data: data}) } // JSON serializes the given struct as JSON into the response body. diff --git a/web/examples/greeting/main.go b/web/examples/greeting/main.go index 067743ee..c7f85722 100644 --- a/web/examples/greeting/main.go +++ b/web/examples/greeting/main.go @@ -18,7 +18,6 @@ package main import ( "context" - "fmt" "log/slog" "math/rand" "mime/multipart" @@ -36,10 +35,10 @@ type Greeting struct { } func (g *Greeting) OnInit(ctx context.Context) error { - g.Server.Bind("/greeting", g.Greeting) - g.Server.Bind("/health", g.Health) - g.Server.Bind("/user/register/{username}/{password}", g.Register) - g.Server.Bind("/user/password", g.UpdatePassword) + g.Server.Get("/greeting", g.Greeting) + g.Server.Get("/health", g.Health) + g.Server.Post("/user/register/{username}/{password}", g.Register) + g.Server.Post("/user/password", g.UpdatePassword) g.Server.Use(func(handler http.Handler) http.Handler { @@ -61,11 +60,11 @@ func (g *Greeting) Greeting(ctx context.Context) string { return "greeting!!!" } -func (g *Greeting) Health(ctx context.Context) (string, error) { +func (g *Greeting) Health(ctx context.Context) error { if 0 == rand.Int()%2 { - return "", fmt.Errorf("health check failed") + return web.Error(400, "health check failed") } - return time.Now().String(), nil + return nil } func (g *Greeting) Register( diff --git a/web/render/binary.go b/web/render/binary.go index f25dc0f2..781c9113 100644 --- a/web/render/binary.go +++ b/web/render/binary.go @@ -21,18 +21,19 @@ import ( ) type BinaryRenderer struct { - ContentType string - Data []byte + DataType string // Content-Type + Data []byte } -func (b BinaryRenderer) Render(writer http.ResponseWriter) error { - if header := writer.Header(); len(header.Get("Content-Type")) == 0 { - contentType := "application/octet-stream" - if len(b.ContentType) > 0 { - contentType = b.ContentType - } - header.Set("Content-Type", contentType) +func (b BinaryRenderer) ContentType() string { + contentType := "application/octet-stream" + if len(b.DataType) > 0 { + contentType = b.DataType } + return contentType +} + +func (b BinaryRenderer) Render(writer http.ResponseWriter) error { _, err := writer.Write(b.Data) return err } diff --git a/web/render/binary_test.go b/web/render/binary_test.go index 444ff024..cd1df094 100644 --- a/web/render/binary_test.go +++ b/web/render/binary_test.go @@ -33,9 +33,10 @@ func TestBinaryRenderer(t *testing.T) { w := httptest.NewRecorder() - err := BinaryRenderer{ContentType: "application/octet-stream", Data: data}.Render(w) + render := BinaryRenderer{DataType: "application/octet-stream", Data: data} + err := render.Render(w) assert.Nil(t, err) - assert.Equal(t, w.Header().Get("Content-Type"), "application/octet-stream") + assert.Equal(t, render.ContentType(), "application/octet-stream") assert.Equal(t, w.Body.Bytes(), data) } diff --git a/web/render/html.go b/web/render/html.go index 603a1aa1..7ae2a4de 100644 --- a/web/render/html.go +++ b/web/render/html.go @@ -27,10 +27,11 @@ type HTMLRenderer struct { Data interface{} } +func (h HTMLRenderer) ContentType() string { + return "text/html; charset=utf-8" +} + func (h HTMLRenderer) Render(writer http.ResponseWriter) error { - if header := writer.Header(); len(header.Get("Content-Type")) == 0 { - header.Set("Content-Type", "text/html; charset=utf-8") - } if len(h.Name) > 0 { return h.Template.ExecuteTemplate(writer, h.Name, h.Data) } diff --git a/web/render/html_test.go b/web/render/html_test.go index d16e9e89..cd0a9a36 100644 --- a/web/render/html_test.go +++ b/web/render/html_test.go @@ -33,6 +33,6 @@ func TestHTMLRenderer(t *testing.T) { err := htmlRender.Render(w) assert.Nil(t, err) - assert.Equal(t, w.Header().Get("Content-Type"), "text/html; charset=utf-8") + assert.Equal(t, htmlRender.ContentType(), "text/html; charset=utf-8") assert.Equal(t, w.Body.String(), "Hello asdklajhdasdd") } diff --git a/web/render/json.go b/web/render/json.go index a455c5d0..c6169788 100644 --- a/web/render/json.go +++ b/web/render/json.go @@ -27,10 +27,11 @@ type JsonRenderer struct { Data interface{} } +func (j JsonRenderer) ContentType() string { + return "application/json; charset=utf-8" +} + func (j JsonRenderer) Render(writer http.ResponseWriter) error { - if header := writer.Header(); len(header.Get("Content-Type")) == 0 { - header.Set("Content-Type", "application/json; charset=utf-8") - } encoder := json.NewEncoder(writer) if len(j.Prefix) > 0 || len(j.Indent) > 0 { encoder.SetIndent(j.Prefix, j.Indent) diff --git a/web/render/json_test.go b/web/render/json_test.go index e4dc05a3..ab4bc411 100644 --- a/web/render/json_test.go +++ b/web/render/json_test.go @@ -31,9 +31,10 @@ func TestJSONRenderer(t *testing.T) { w := httptest.NewRecorder() - err := JsonRenderer{Data: data}.Render(w) + render := JsonRenderer{Data: data} + err := render.Render(w) assert.Nil(t, err) - assert.Equal(t, w.Header().Get("Content-Type"), "application/json; charset=utf-8") + assert.Equal(t, render.ContentType(), "application/json; charset=utf-8") assert.Equal(t, w.Body.String(), "{\"foo\":\"bar\",\"html\":\"\\u003cb\\u003e\"}\n") } diff --git a/web/render/redirect.go b/web/render/redirect.go index 02521516..d7581863 100644 --- a/web/render/redirect.go +++ b/web/render/redirect.go @@ -27,6 +27,10 @@ type RedirectRenderer struct { Location string } +func (r RedirectRenderer) ContentType() string { + return "" +} + func (r RedirectRenderer) Render(writer http.ResponseWriter) error { if (r.Code < http.StatusMultipleChoices || r.Code > http.StatusPermanentRedirect) && r.Code != http.StatusCreated { panic(fmt.Sprintf("Cannot redirect with status code %d", r.Code)) diff --git a/web/render/redirect_test.go b/web/render/redirect_test.go index c02d543a..e1084fde 100644 --- a/web/render/redirect_test.go +++ b/web/render/redirect_test.go @@ -37,6 +37,7 @@ func TestRedirectRenderer(t *testing.T) { w := httptest.NewRecorder() err = data1.Render(w) assert.Nil(t, err) + assert.Equal(t, data1.ContentType(), "") data2 := RedirectRenderer{ Code: http.StatusOK, diff --git a/web/render/renderer.go b/web/render/renderer.go index 7d228ea9..969e7161 100644 --- a/web/render/renderer.go +++ b/web/render/renderer.go @@ -16,9 +16,12 @@ package render -import "net/http" +import ( + "net/http" +) // Renderer writes data with custom ContentType and headers. type Renderer interface { + ContentType() string Render(writer http.ResponseWriter) error } diff --git a/web/render/text.go b/web/render/text.go index a7cff06b..f81d2081 100644 --- a/web/render/text.go +++ b/web/render/text.go @@ -28,10 +28,11 @@ type TextRenderer struct { Args []interface{} } +func (t TextRenderer) ContentType() string { + return "text/plain; charset=utf-8" +} + func (t TextRenderer) Render(writer http.ResponseWriter) error { - if header := writer.Header(); len(header.Get("Content-Type")) == 0 { - header.Set("Content-Type", "text/plain; charset=utf-8") - } _, err := io.Copy(writer, strings.NewReader(fmt.Sprintf(t.Format, t.Args...))) return err } diff --git a/web/render/text_test.go b/web/render/text_test.go index f30d3424..167efad9 100644 --- a/web/render/text_test.go +++ b/web/render/text_test.go @@ -26,12 +26,14 @@ import ( func TestTextRenderer(t *testing.T) { w := httptest.NewRecorder() - err := (TextRenderer{ + render := TextRenderer{ Format: "hello %s %d", Args: []any{"bob", 2}, - }).Render(w) + } + + err := render.Render(w) assert.Nil(t, err) - assert.Equal(t, w.Header().Get("Content-Type"), "text/plain; charset=utf-8") + assert.Equal(t, render.ContentType(), "text/plain; charset=utf-8") assert.Equal(t, w.Body.String(), "hello bob 2") } diff --git a/web/render/xml.go b/web/render/xml.go index bcea5d6d..19533b8f 100644 --- a/web/render/xml.go +++ b/web/render/xml.go @@ -27,11 +27,11 @@ type XmlRenderer struct { Data interface{} } -func (x XmlRenderer) Render(writer http.ResponseWriter) error { - if header := writer.Header(); len(header.Get("Content-Type")) == 0 { - header.Set("Content-Type", "application/xml; charset=utf-8") - } +func (x XmlRenderer) ContentType() string { + return "application/xml; charset=utf-8" +} +func (x XmlRenderer) Render(writer http.ResponseWriter) error { encoder := xml.NewEncoder(writer) if len(x.Prefix) > 0 || len(x.Indent) > 0 { encoder.Indent(x.Prefix, x.Indent) diff --git a/web/render/xml_test.go b/web/render/xml_test.go index 1e46c93f..cbc53658 100644 --- a/web/render/xml_test.go +++ b/web/render/xml_test.go @@ -54,9 +54,10 @@ func TestXmlRenderer(t *testing.T) { "foo": "bar", } - err := (XmlRenderer{Data: data}).Render(w) + render := (XmlRenderer{Data: data}) + err := render.Render(w) assert.Nil(t, err) - assert.Equal(t, w.Header().Get("Content-Type"), "application/xml; charset=utf-8") + assert.Equal(t, render.ContentType(), "application/xml; charset=utf-8") assert.Equal(t, w.Body.String(), "bar") } diff --git a/web/server.go b/web/server.go index 247fc0aa..3b61543f 100644 --- a/web/server.go +++ b/web/server.go @@ -176,9 +176,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.router.ServeHTTP(w, req) } -// Get returns a route registered with the given name. -func (s *Server) Get(name string) *Route { - return s.router.Get(name) +// GetRoute returns a route registered with the given name. +func (s *Server) GetRoute(name string) *Route { + return s.router.GetRoute(name) } // StrictSlash defines the trailing slash behavior for new routes. The initial @@ -250,20 +250,25 @@ func (s *Server) HandleFunc(path string, f func(http.ResponseWriter, *http.Reque } // Bind registers a new route with a matcher for the URL path. +// Automatic binding request to handler input params, following functions: // // func(ctx context.Context) // // func(ctx context.Context) R // +// func(ctx context.Context) error +// // func(ctx context.Context, req T) R // +// func(ctx context.Context, req T) error +// // func(ctx context.Context, req T) (R, error) -func (s *Server) Bind(path string, f interface{}, r ...Renderer) *Route { +func (s *Server) Bind(path string, handler interface{}, r ...Renderer) *Route { var renderer = s.renderer if len(r) > 0 { renderer = r[0] } - return s.Handle(path, Bind(f, renderer)) + return s.Handle(path, Bind(handler, renderer)) } // Headers registers a new route with a matcher for request header values. @@ -320,3 +325,57 @@ func (s *Server) BuildVarsFunc(f BuildVarsFunc) *Route { func (s *Server) Walk(walkFn WalkFunc) error { return s.router.Walk(walkFn) } + +// Get registers a new GET route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Get(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodGet) +} + +// Head registers a new HEAD route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Head(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodHead) +} + +// Post registers a new POST route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Post(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodPost) +} + +// Put registers a new PUT route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Put(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodPut) +} + +// Patch registers a new PATCH route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Patch(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodPatch) +} + +// Delete registers a new DELETE route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Delete(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodDelete) +} + +// Connect registers a new CONNECT route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Connect(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodConnect) +} + +// Options registers a new OPTIONS route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Options(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodOptions) +} + +// Trace registers a new TRACE route with a matcher for the URL path of the get method. +// See Server.Bind() +func (s *Server) Trace(path string, handler interface{}, r ...Renderer) *Route { + return s.Bind(path, handler, r...).Methods(http.MethodTrace) +}