diff --git a/backend_client.go b/backend_client.go index 8b60869f..070b102e 100644 --- a/backend_client.go +++ b/backend_client.go @@ -27,7 +27,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "log" "net/http" "net/url" @@ -51,6 +50,7 @@ type BackendClient struct { pool *HttpClientPool capabilities *Capabilities + buffers BufferPool } func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) { @@ -175,12 +175,14 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return ErrUnsupportedContentType } - body, err := io.ReadAll(resp.Body) + body, err := b.buffers.ReadAll(resp.Body) if err != nil { log.Printf("Could not read response body from %s: %s", req.URL, err) return err } + defer b.buffers.Put(body) + if isOcsRequest(u) || req.Header.Get("OCS-APIRequest") != "" { // OCS response are wrapped in an OCS container that needs to be parsed // to get the actual contents: @@ -191,17 +193,17 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ // } // } var ocs OcsResponse - if err := json.Unmarshal(body, &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", string(body), req.URL, err) + if err := json.Unmarshal(body.Bytes(), &ocs); err != nil { + log.Printf("Could not decode OCS response %s from %s: %s", body.String(), req.URL, err) return err } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { - log.Printf("Incomplete OCS response %s from %s", string(body), req.URL) + log.Printf("Incomplete OCS response %s from %s", body.String(), req.URL) return ErrIncompleteResponse } switch ocs.Ocs.Meta.StatusCode { case http.StatusTooManyRequests: - log.Printf("Throttled OCS response %s from %s", string(body), req.URL) + log.Printf("Throttled OCS response %s from %s", body.String(), req.URL) return ErrThrottledResponse } @@ -209,8 +211,8 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), req.URL, err) return err } - } else if err := json.Unmarshal(body, response); err != nil { - log.Printf("Could not decode response body %s from %s: %s", string(body), req.URL, err) + } else if err := json.Unmarshal(body.Bytes(), response); err != nil { + log.Printf("Could not decode response body %s from %s: %s", body.String(), req.URL, err) return err } return nil diff --git a/backend_server.go b/backend_server.go index 016f0eb1..1ce07bfb 100644 --- a/backend_server.go +++ b/backend_server.go @@ -70,6 +70,8 @@ type BackendServer struct { statsAllowedIps atomic.Pointer[AllowedIps] invalidSecret []byte + + buffers BufferPool } func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) { @@ -284,14 +286,15 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque return } - body, err := io.ReadAll(r.Body) + body, err := b.buffers.ReadAll(r.Body) if err != nil { log.Println("Error reading body: ", err) http.Error(w, "Could not read body", http.StatusBadRequest) return } + defer b.buffers.Put(body) - f(w, r, body) + f(w, r, body.Bytes()) } } diff --git a/capabilities.go b/capabilities.go index e606bf03..554ca79c 100644 --- a/capabilities.go +++ b/capabilities.go @@ -25,7 +25,6 @@ import ( "context" "encoding/json" "errors" - "io" "log" "net/http" "net/url" @@ -183,18 +182,20 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim return e.errorIfMustRevalidate(ErrUnsupportedContentType) } - body, err := io.ReadAll(response.Body) + body, err := e.c.buffers.ReadAll(response.Body) if err != nil { log.Printf("Could not read response body from %s: %s", url, err) return e.errorIfMustRevalidate(err) } + defer e.c.buffers.Put(body) + var ocs OcsResponse - if err := json.Unmarshal(body, &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", string(body), url, err) + if err := json.Unmarshal(body.Bytes(), &ocs); err != nil { + log.Printf("Could not decode OCS response %s from %s: %s", body.String(), url, err) return e.errorIfMustRevalidate(err) } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { - log.Printf("Incomplete OCS response %s from %s", string(body), url) + log.Printf("Incomplete OCS response %s from %s", body.String(), url) return e.errorIfMustRevalidate(ErrIncompleteResponse) } @@ -240,6 +241,8 @@ type Capabilities struct { pool *HttpClientPool entries map[string]*capabilitiesEntry nextInvalidate map[string]time.Time + + buffers BufferPool } func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) { diff --git a/client.go b/client.go index 3980218c..145e7d91 100644 --- a/client.go +++ b/client.go @@ -82,11 +82,7 @@ func IsValidCountry(country string) bool { var ( InvalidFormat = NewError("invalid_format", "Invalid data format.") - bufferPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, - } + bufferPool BufferPool ) type WritableClientMessage interface { @@ -390,10 +386,8 @@ func (c *Client) ReadPump() { continue } - decodeBuffer := bufferPool.Get().(*bytes.Buffer) - decodeBuffer.Reset() - if _, err := decodeBuffer.ReadFrom(reader); err != nil { - bufferPool.Put(decodeBuffer) + decodeBuffer, err := bufferPool.ReadAll(reader) + if err != nil { if sessionId := c.GetSessionId(); sessionId != "" { log.Printf("Error reading message from client %s: %v", sessionId, err) } else {