Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for non buffered body server responses. #1657

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions bytesconv.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,23 @@ func appendQuotedPath(dst, src []byte) []byte {
}
return dst
}

// countHexDigits returns the number of hex digits required to represent n when using writeHexInt
func countHexDigits(n int) int {
if n < 0 {
// developer sanity-check
panic("BUG: int must be positive")
}

if n == 0 {
return 1
}

count := 0
for n > 0 {
n = n >> 4
count++
}

return count
}
3 changes: 3 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ type Response struct {
raddr net.Addr
// Local TCPAddr from concurrently net.Conn
laddr net.Addr

headersWritten bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this state needs to be kept here, you can just add a headersWritten bool to UnbufferedWriterHttp1.

}

// SetHost sets host for the request.
Expand Down Expand Up @@ -1122,6 +1124,7 @@ func (resp *Response) Reset() {
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
resp.headersWritten = false
}

func (resp *Response) resetSkipHeader() {
Expand Down
70 changes: 70 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,48 @@ type RequestCtx struct {
hijackHandler HijackHandler
hijackNoResponse bool
formValueFunc FormValueFunc

disableBuffering bool // disables buffered response body
getUnbufferedWriter func(*RequestCtx) UnbufferedWriter // defines how to get unbuffered writer
unbufferedWriter UnbufferedWriter // writes directly to underlying connection
bytesSent int // number of bytes sent to client using unbuffered operations
}

// DisableBuffering modifies fasthttp to disable body buffering for this request.
// This is useful for requests that return large data or stream data.
//
// When buffering is disabled you must:
// 1. Set response status and header values before writing body
// 2. Set ContentLength is optional. If not set, the server will use chunked encoding.
// 3. Write body data using methods like ctx.Write or io.Copy(ctx,src), etc.
// 4. Optionally call CloseResponse to finalize the response.
//
// CLosing the response will finalize the response and send the last chunk.
// If the handler does not finish the response, it will be called automatically after handler returns.
// Closing the response will also set BytesSent with the correct number of total bytes sent.
func (ctx *RequestCtx) DisableBuffering() {
ctx.disableBuffering = true

// We need to create a new unbufferedWriter for each unbuffered request.
// This way we can allow different implementations and be compatible with http2 protocol
if ctx.unbufferedWriter == nil {
if ctx.getUnbufferedWriter != nil {
ctx.unbufferedWriter = ctx.getUnbufferedWriter(ctx)
} else {
ctx.unbufferedWriter = NewUnbufferedWriter(ctx)
}
}
}

// CloseResponse finalizes non-buffered response dispatch.
// This method must be called after performing non-buffered responses
// If the handler does not finish the response, it will be called automatically
// after the handler function returns.
func (ctx *RequestCtx) CloseResponse() error {
if !ctx.disableBuffering || ctx.unbufferedWriter == nil {
return ErrNotUnbuffered
}
return ctx.unbufferedWriter.Close()
}

// HijackHandler must process the hijacked connection c.
Expand Down Expand Up @@ -822,6 +864,11 @@ func (ctx *RequestCtx) reset() {

ctx.hijackHandler = nil
ctx.hijackNoResponse = false

ctx.disableBuffering = false
ctx.unbufferedWriter = nil
ctx.getUnbufferedWriter = nil
ctx.bytesSent = 0
}

type firstByteReader struct {
Expand Down Expand Up @@ -1443,10 +1490,28 @@ func (ctx *RequestCtx) NotFound() {

// Write writes p into response body.
func (ctx *RequestCtx) Write(p []byte) (int, error) {
if ctx.disableBuffering {
return ctx.writeDirect(p)
}

ctx.Response.AppendBody(p)
return len(p), nil
}

// writeDirect writes p to underlying connection bypassing any buffering.
func (ctx *RequestCtx) writeDirect(p []byte) (int, error) {
if ctx.unbufferedWriter == nil {
ctx.unbufferedWriter = NewUnbufferedWriter(ctx)
}
return ctx.unbufferedWriter.Write(p)
}

// BytesSent returns the number of bytes sent to the client after non buffered operation.
// Includes headers and body length.
func (ctx *RequestCtx) BytesSent() int {
return ctx.bytesSent
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add this function?

Copy link
Author

@pablolagos pablolagos Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After streaming a response, it can be necessary to know the number of bytes sent. That number is known only if we a serving local resources, but unknown if we are proxying from external sources. Bytes sent can represent a cost related to data-transfer. It can be useful for logging and other analysis.

We could return that value in ctx.CloseReponse()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I would remove this. This isn't supported with normal responses either so it's weird to only add it for this case now.


// WriteString appends s to response body.
func (ctx *RequestCtx) WriteString(s string) (int, error) {
ctx.Response.AppendBodyString(s)
Expand Down Expand Up @@ -2359,6 +2424,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
s.Handler(ctx)
}

if ctx.disableBuffering {
_ = ctx.CloseResponse()
break
}

timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
// Acquire a new ctx because the old one will still be in use by the timeout out handler.
Expand Down
70 changes: 70 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4237,6 +4237,76 @@ func TestServerChunkedResponse(t *testing.T) {
}
}

func TestServerDisableBuffering(t *testing.T) {
t.Parallel()

received := make(chan bool)
done := make(chan bool)

expectedBody := bytes.Repeat([]byte("a"), 4096)

s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.DisableBuffering()
ctx.SetStatusCode(StatusOK)
ctx.SetContentType("text/html; charset=utf-8")
reader := bytes.NewReader(expectedBody)
_, err := io.Copy(ctx, reader)
if err != nil {
t.Fatalf("Unexpected error when copying body: %v", err)
}
pablolagos marked this conversation as resolved.
Show resolved Hide resolved
ctx.CloseResponse()
if len(ctx.Response.Body()) > 0 {
t.Fatalf("Body was populated when buffer was disabled")
}

// wait until body is received by the consumer or stop after 2 seconds timeout
select {
case <-received:
case <-time.After(2 * time.Second):
t.Fatal("Body not received by consumer after 2 seconds")
}

// The consumer received the body, so we can finish the test
done <- true
},
}

ln := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()

conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET /index.html HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)

var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %v", err)
}
if resp.Header.ContentLength() != -1 {
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1)
}
if !bytes.Equal(resp.Body(), expectedBody) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar")
}

// Signal that the body was received correctly
received <- true

// Wait until the server has finished
<-done
}

func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
var resp Response
if err := resp.Read(r); err != nil {
Expand Down
116 changes: 116 additions & 0 deletions unbuffered.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package fasthttp

import (
"bufio"
"errors"
"fmt"
)

type UnbufferedWriter interface {
Write(p []byte) (int, error)
WriteHeaders() (int, error)
Close() error
}

type UnbufferedWriterHttp1 struct {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this type should be exposed, it can be kept internally. I also wouldn't add Http1 to the name as everything in fastthttp is http1.

writer *bufio.Writer
ctx *RequestCtx
bodyChunkStarted bool
bodyLastChunkSent bool
}

var ErrNotUnbuffered = errors.New("not unbuffered")
var ErrClosedUnbufferedWriter = errors.New("closed unbuffered writer")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it something like use of closed unbuffered writer


// Ensure UnbufferedWriterHttp1 implements UnbufferedWriter.
var _ UnbufferedWriter = &UnbufferedWriterHttp1{}

// NewUnbufferedWriter
//
// Object must be discarded when request is finished
func NewUnbufferedWriter(ctx *RequestCtx) *UnbufferedWriterHttp1 {
writer := acquireWriter(ctx)
return &UnbufferedWriterHttp1{ctx: ctx, writer: writer}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function can also be kept internal then.


func (uw *UnbufferedWriterHttp1) Write(p []byte) (int, error) {
if uw.writer == nil || uw.ctx == nil {
return 0, ErrClosedUnbufferedWriter
}

// Write headers if not already sent
if !uw.ctx.Response.headersWritten {
_, err := uw.WriteHeaders()
if err != nil {
return 0, fmt.Errorf("error writing headers: %w", err)
}
}

// Write body. In chunks if content length is not set.
if uw.ctx.Response.Header.contentLength == -1 && uw.ctx.Response.Header.IsHTTP11() {
uw.bodyChunkStarted = true
err := writeChunk(uw.writer, p)
if err != nil {
return 0, err
}
uw.ctx.bytesSent += len(p) + 4 + countHexDigits(len(p))
return len(p), nil
}

n, err := uw.writer.Write(p)
uw.ctx.bytesSent += n

return n, err
}

func (uw *UnbufferedWriterHttp1) WriteHeaders() (int, error) {
if uw.writer == nil || uw.ctx == nil {
return 0, ErrClosedUnbufferedWriter
}

if !uw.ctx.Response.headersWritten {
if uw.ctx.Response.Header.contentLength == 0 && uw.ctx.Response.Header.IsHTTP11() {
if uw.ctx.Response.SkipBody {
uw.ctx.Response.Header.SetContentLength(0)
} else {
uw.ctx.Response.Header.SetContentLength(-1) // means Transfer-Encoding = chunked
}
}
h := uw.ctx.Response.Header.Header()
n, err := uw.writer.Write(h)
Comment on lines +79 to +80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably faster and more efficient to use uw.ctx.Response.Header.WriteTo(uw.writer)

if err != nil {
return 0, err
}
uw.ctx.bytesSent += n
uw.ctx.Response.headersWritten = true
}
return 0, nil
}

func (uw *UnbufferedWriterHttp1) Close() error {
if uw.writer == nil || uw.ctx == nil {
return ErrClosedUnbufferedWriter
}

// write headers if not already sent (e.g. if there is no body written)
if !uw.ctx.Response.headersWritten {
// skip body, as we are closing without writing body
uw.ctx.Response.SkipBody = true
_, err := uw.WriteHeaders()
if err != nil {
return fmt.Errorf("error writing headers: %w", err)
}
}
Comment on lines +96 to +103
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test to make sure this works, calling CloseResponse() without writing anything to the body.


// finalize chunks
if uw.bodyChunkStarted && uw.ctx.Response.Header.IsHTTP11() && !uw.bodyLastChunkSent {
_, _ = uw.writer.Write([]byte("0\r\n\r\n"))
uw.ctx.bytesSent += 5
}
_ = uw.writer.Flush()
uw.bodyLastChunkSent = true
releaseWriter(uw.ctx.s, uw.writer)
uw.writer = nil
uw.ctx = nil
return nil
}