Skip to content

Commit

Permalink
Use buffer pool for reading data.
Browse files Browse the repository at this point in the history
  • Loading branch information
fancycode committed Dec 19, 2024
1 parent 684416f commit 5150d51
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 24 deletions.
18 changes: 10 additions & 8 deletions backend_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
Expand All @@ -51,6 +50,7 @@ type BackendClient struct {

pool *HttpClientPool
capabilities *Capabilities
buffers BufferPool
}

func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) {
Expand Down Expand Up @@ -175,12 +175,14 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
return ErrUnsupportedContentType
}

body, err := io.ReadAll(resp.Body)
body, err := b.buffers.ReadAll(resp.Body)
if err != nil {
log.Printf("Could not read response body from %s: %s", req.URL, err)
return err
}

defer b.buffers.Put(body)

if isOcsRequest(u) || req.Header.Get("OCS-APIRequest") != "" {
// OCS response are wrapped in an OCS container that needs to be parsed
// to get the actual contents:
Expand All @@ -191,26 +193,26 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
// }
// }
var ocs OcsResponse
if err := json.Unmarshal(body, &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", string(body), req.URL, err)
if err := json.Unmarshal(body.Bytes(), &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", body.String(), req.URL, err)
return err
} else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 {
log.Printf("Incomplete OCS response %s from %s", string(body), req.URL)
log.Printf("Incomplete OCS response %s from %s", body.String(), req.URL)
return ErrIncompleteResponse
}

switch ocs.Ocs.Meta.StatusCode {
case http.StatusTooManyRequests:
log.Printf("Throttled OCS response %s from %s", string(body), req.URL)
log.Printf("Throttled OCS response %s from %s", body.String(), req.URL)
return ErrThrottledResponse
}

if err := json.Unmarshal(ocs.Ocs.Data, response); err != nil {
log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), req.URL, err)
return err
}
} else if err := json.Unmarshal(body, response); err != nil {
log.Printf("Could not decode response body %s from %s: %s", string(body), req.URL, err)
} else if err := json.Unmarshal(body.Bytes(), response); err != nil {
log.Printf("Could not decode response body %s from %s: %s", body.String(), req.URL, err)
return err
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions backend_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type BackendServer struct {

statsAllowedIps atomic.Pointer[AllowedIps]
invalidSecret []byte

buffers BufferPool
}

func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) {
Expand Down Expand Up @@ -284,14 +286,15 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque
return
}

body, err := io.ReadAll(r.Body)
body, err := b.buffers.ReadAll(r.Body)
if err != nil {
log.Println("Error reading body: ", err)
http.Error(w, "Could not read body", http.StatusBadRequest)
return
}
defer b.buffers.Put(body)

f(w, r, body)
f(w, r, body.Bytes())
}
}

Expand Down
13 changes: 8 additions & 5 deletions capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"context"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"net/url"
Expand Down Expand Up @@ -183,18 +182,20 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim
return e.errorIfMustRevalidate(ErrUnsupportedContentType)
}

body, err := io.ReadAll(response.Body)
body, err := e.c.buffers.ReadAll(response.Body)
if err != nil {
log.Printf("Could not read response body from %s: %s", url, err)
return e.errorIfMustRevalidate(err)
}

defer e.c.buffers.Put(body)

var ocs OcsResponse
if err := json.Unmarshal(body, &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", string(body), url, err)
if err := json.Unmarshal(body.Bytes(), &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", body.String(), url, err)
return e.errorIfMustRevalidate(err)
} else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 {
log.Printf("Incomplete OCS response %s from %s", string(body), url)
log.Printf("Incomplete OCS response %s from %s", body.String(), url)
return e.errorIfMustRevalidate(ErrIncompleteResponse)
}

Expand Down Expand Up @@ -240,6 +241,8 @@ type Capabilities struct {
pool *HttpClientPool
entries map[string]*capabilitiesEntry
nextInvalidate map[string]time.Time

buffers BufferPool
}

func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
Expand Down
12 changes: 3 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,7 @@ func IsValidCountry(country string) bool {
var (
InvalidFormat = NewError("invalid_format", "Invalid data format.")

bufferPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
bufferPool BufferPool
)

type WritableClientMessage interface {
Expand Down Expand Up @@ -390,10 +386,8 @@ func (c *Client) ReadPump() {
continue
}

decodeBuffer := bufferPool.Get().(*bytes.Buffer)
decodeBuffer.Reset()
if _, err := decodeBuffer.ReadFrom(reader); err != nil {
bufferPool.Put(decodeBuffer)
decodeBuffer, err := bufferPool.ReadAll(reader)
if err != nil {
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading message from client %s: %v", sessionId, err)
} else {
Expand Down

0 comments on commit 5150d51

Please sign in to comment.