Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move otelhttp wrappers into internal package #5916

Merged
merged 22 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

- The deprecated `go.opentelemetry.io/contrib/processors/baggagecopy` package is removed. (#5853)

### Fixed

- Race condition when reading the HTTP body and writing the response in `go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp`. (#5916)

<!-- Released section -->
<!-- Don't change this section unless doing release -->

Expand Down
37 changes: 16 additions & 21 deletions instrumentation/net/http/otelhttp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/felixge/httpsnoop"

"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconv"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/semconvutil"
"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -166,14 +167,12 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http
}
}

var bw bodyWrapper
// if request body is nil or NoBody, we don't want to mutate the body as it
// will affect the identity of it in an unforeseeable way because we assert
// ReadCloser fulfills a certain interface and it is indeed nil or NoBody.
bw := request.NewBodyWrapper(r.Body, readRecordFunc)
if r.Body != nil && r.Body != http.NoBody {
bw.ReadCloser = r.Body
bw.record = readRecordFunc
r.Body = &bw
r.Body = bw
}

writeRecordFunc := func(int64) {}
Expand All @@ -183,13 +182,7 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http
}
}

rww := &respWriterWrapper{
ResponseWriter: w,
record: writeRecordFunc,
ctx: ctx,
props: h.propagators,
statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything
}
rww := request.NewRespWriterWrapper(w, writeRecordFunc)

// Wrap w to use our ResponseWriter methods while also exposing
// other interfaces that w may implement (http.CloseNotifier,
Expand Down Expand Up @@ -217,24 +210,26 @@ func (h *middleware) serveHTTP(w http.ResponseWriter, r *http.Request, next http

next.ServeHTTP(w, r.WithContext(ctx))

span.SetStatus(semconv.ServerStatus(rww.statusCode))
statusCode := rww.StatusCode()
bytesWritten := rww.BytesWritten()
span.SetStatus(semconv.ServerStatus(statusCode))
span.SetAttributes(h.traceSemconv.ResponseTraceAttrs(semconv.ResponseTelemetry{
StatusCode: rww.statusCode,
ReadBytes: bw.read.Load(),
ReadError: bw.err,
WriteBytes: rww.written,
WriteError: rww.err,
StatusCode: statusCode,
ReadBytes: bw.BytesRead(),
ReadError: bw.Error(),
WriteBytes: bytesWritten,
WriteError: rww.Error(),
})...)

// Add metrics
attributes := append(labeler.Get(), semconvutil.HTTPServerRequestMetrics(h.server, r)...)
if rww.statusCode > 0 {
attributes = append(attributes, semconv.HTTPStatusCode(rww.statusCode))
if statusCode > 0 {
attributes = append(attributes, semconv.HTTPStatusCode(statusCode))
}
o := metric.WithAttributeSet(attribute.NewSet(attributes...))

h.requestBytesCounter.Add(ctx, bw.read.Load(), o)
h.responseBytesCounter.Add(ctx, rww.written, o)
h.requestBytesCounter.Add(ctx, bw.BytesRead(), o)
h.responseBytesCounter.Add(ctx, bytesWritten, o)

// Use floating point division here for higher precision (instead of Millisecond method).
elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package request // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request"

import (
"io"
"sync"
)

var _ io.ReadCloser = &BodyWrapper{}

// BodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number
// of bytes read and the last error.
type BodyWrapper struct {
io.ReadCloser
OnRead func(n int64) // must not be nil

mu sync.Mutex
read int64
err error
}

// NewBodyWrapper creates a new BodyWrapper.
dmathieu marked this conversation as resolved.
Show resolved Hide resolved
func NewBodyWrapper(rc io.ReadCloser, onRead func(int64)) *BodyWrapper {
dmathieu marked this conversation as resolved.
Show resolved Hide resolved
return &BodyWrapper{
ReadCloser: rc,
OnRead: onRead,
}
}

// Read reads the data from the io.ReadCloser, and stores the number of bytes
// read and the error.
func (w *BodyWrapper) Read(b []byte) (int, error) {
n, err := w.ReadCloser.Read(b)
n1 := int64(n)

w.updateReadData(n1, err)
w.OnRead(n1)
return n, err
}

func (w *BodyWrapper) updateReadData(n int64, err error) {
w.mu.Lock()
defer w.mu.Unlock()

w.read = w.read + n
dmathieu marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
w.err = err
}
}

// Closes closes the io.ReadCloser.
func (w *BodyWrapper) Close() error {
return w.ReadCloser.Close()
}

// BytesRead returns the number of bytes read up to this point.
func (w *BodyWrapper) BytesRead() int64 {
w.mu.Lock()
defer w.mu.Unlock()

return w.read
}

// Error returns the last error.
func (w *BodyWrapper) Error() error {
w.mu.Lock()
defer w.mu.Unlock()

return w.err
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package request

import (
"errors"
"io"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBodyWrapper(t *testing.T) {
bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {})

data, err := io.ReadAll(bw)
require.NoError(t, err)
assert.Equal(t, "hello world", string(data))

assert.Equal(t, int64(11), bw.BytesRead())
assert.Equal(t, io.EOF, bw.Error())
}

type multipleErrorsReader struct {
calls int
}

type errorWrapper struct{}

func (errorWrapper) Error() string {
return "subsequent calls"
}

func (mer *multipleErrorsReader) Read([]byte) (int, error) {
mer.calls = mer.calls + 1
if mer.calls == 1 {
return 0, errors.New("first call")
dmathieu marked this conversation as resolved.
Show resolved Hide resolved
}

return 0, errorWrapper{}
}

func TestBodyWrapperWithErrors(t *testing.T) {
bw := NewBodyWrapper(io.NopCloser(&multipleErrorsReader{}), func(int64) {})

data, err := io.ReadAll(bw)
require.Equal(t, errors.New("first call"), err)
assert.Equal(t, "", string(data))
require.Equal(t, errors.New("first call"), bw.Error())

data, err = io.ReadAll(bw)
require.Equal(t, errorWrapper{}, err)
assert.Equal(t, "", string(data))
require.Equal(t, errorWrapper{}, bw.Error())
}

func TestConcurrentBodyWrapper(t *testing.T) {
bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {})

go func() {
_, _ = io.ReadAll(bw)
}()

assert.NotNil(t, bw.BytesRead())
assert.NoError(t, bw.Error())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package request // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/internal/request"

import (
"net/http"
"sync"
)

var _ http.ResponseWriter = &RespWriterWrapper{}

// RespWriterWrapper wraps a http.ResponseWriter in order to track the number of
// bytes written, the last error, and to catch the first written statusCode.
// TODO: The wrapped http.ResponseWriter doesn't implement any of the optional
// types (http.Hijacker, http.Pusher, http.CloseNotifier, etc)
// that may be useful when using it in real life situations.
type RespWriterWrapper struct {
http.ResponseWriter
OnWrite func(n int64) // must not be nil

mu sync.RWMutex
written int64
statusCode int
err error
wroteHeader bool
}

// NewRespWriterWrapper creates a new RespWriterWrapper.
func NewRespWriterWrapper(w http.ResponseWriter, onWrite func(int64)) *RespWriterWrapper {
return &RespWriterWrapper{
ResponseWriter: w,
OnWrite: onWrite,
statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything
}
}

// Header returns the response writer HTTP headers.
func (w *RespWriterWrapper) Header() http.Header {
return w.ResponseWriter.Header()

Check warning on line 40 in instrumentation/net/http/otelhttp/internal/request/resp_writer_wrapper.go

View check run for this annotation

Codecov / codecov/patch

instrumentation/net/http/otelhttp/internal/request/resp_writer_wrapper.go#L39-L40

Added lines #L39 - L40 were not covered by tests
dmathieu marked this conversation as resolved.
Show resolved Hide resolved
}

// Write writes the bytes array into the [ResponseWriter], and tracks the
// number of bytes written and last error.
func (w *RespWriterWrapper) Write(p []byte) (int, error) {
w.WriteHeader(http.StatusOK)

w.mu.RLock()
defer w.mu.RUnlock()
dmathieu marked this conversation as resolved.
Show resolved Hide resolved

n, err := w.ResponseWriter.Write(p)
Dismissed Show dismissed Hide dismissed
n1 := int64(n)
w.OnWrite(n1)
w.written += n1
w.err = err
return n, err
}

// WriteHeader persists initial statusCode for span attribution.
// All calls to WriteHeader will be propagated to the underlying ResponseWriter
// and will persist the statusCode from the first call.
// Blocking consecutive calls to WriteHeader alters expected behavior and will
// remove warning logs from net/http where developers will notice incorrect handler implementations.
func (w *RespWriterWrapper) WriteHeader(statusCode int) {
w.mu.Lock()
defer w.mu.Unlock()

if !w.wroteHeader {
w.wroteHeader = true
w.statusCode = statusCode
}
w.ResponseWriter.WriteHeader(statusCode)
}

// Flush implements [http.Flusher].
func (w *RespWriterWrapper) Flush() {
w.WriteHeader(http.StatusOK)

if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

// BytesWritten returns the number of bytes written.
func (w *RespWriterWrapper) BytesWritten() int64 {
w.mu.RLock()
defer w.mu.RUnlock()

return w.written
}

// BytesWritten returns the HTTP status code that was sent.
func (w *RespWriterWrapper) StatusCode() int {
w.mu.RLock()
defer w.mu.RUnlock()

return w.statusCode
}

// Error returns the last error.
func (w *RespWriterWrapper) Error() error {
w.mu.RLock()
defer w.mu.RUnlock()

return w.err
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package otelhttp
package request

import (
"net/http"
Expand All @@ -12,10 +12,7 @@ import (
)

func TestRespWriterWriteHeader(t *testing.T) {
rw := &respWriterWrapper{
ResponseWriter: &httptest.ResponseRecorder{},
record: func(int64) {},
}
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {})

rw.WriteHeader(http.StatusTeapot)
assert.Equal(t, http.StatusTeapot, rw.statusCode)
Expand All @@ -26,10 +23,7 @@ func TestRespWriterWriteHeader(t *testing.T) {
}

func TestRespWriterFlush(t *testing.T) {
rw := &respWriterWrapper{
ResponseWriter: &httptest.ResponseRecorder{},
record: func(int64) {},
}
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {})

rw.Flush()
assert.Equal(t, http.StatusOK, rw.statusCode)
Expand All @@ -49,12 +43,21 @@ func (_ nonFlushableResponseWriter) Write([]byte) (int, error) {
func (_ nonFlushableResponseWriter) WriteHeader(int) {}

func TestRespWriterFlushNoFlusher(t *testing.T) {
rw := &respWriterWrapper{
ResponseWriter: nonFlushableResponseWriter{},
record: func(int64) {},
}
rw := NewRespWriterWrapper(nonFlushableResponseWriter{}, func(int64) {})

rw.Flush()
assert.Equal(t, http.StatusOK, rw.statusCode)
assert.True(t, rw.wroteHeader)
}

func TestConcurrentRespWriterWrapper(t *testing.T) {
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {})

go func() {
_, _ = rw.Write([]byte("hello world"))
}()

assert.NotNil(t, rw.BytesWritten())
assert.NotNil(t, rw.StatusCode())
assert.NoError(t, rw.Error())
}
Loading