diff --git a/wrappers/net/http/client.go b/wrappers/net/http/client.go index 2737b9d..09c803c 100644 --- a/wrappers/net/http/client.go +++ b/wrappers/net/http/client.go @@ -95,10 +95,12 @@ func NewWrappedTracingTransport(rt http.RoundTripper, args ...context.Context) * // RoundTrip implements the RoundTripper interface to trace HTTP calls func (t *TracingTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + // reference to the tracer + tr := t.tracer // if the TracingTransport is created before the global tracer is created it will be nil - if t.tracer == nil { - t.tracer = internal.ExtractTracer(nil) - if t.tracer != nil && t.tracer.GetConfig().Debug { + if tr == nil { + tr = internal.ExtractTracer(nil) + if tr != nil && tr.GetConfig().Debug { log.Println("EPSAGON DEBUG: defaulting to global tracer in RoundTrip") } } @@ -111,6 +113,7 @@ func (t *TracingTransport) RoundTrip(req *http.Request) (resp *http.Response, er }() defer epsagon.GeneralEpsagonRecover("net.http.RoundTripper", "RoundTrip", t.tracer) startTime := tracer.GetTimestamp() + reqHeaders, reqBody := t.extractRequestData(req, tr) if !isBlacklistedURL(req.URL) { req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()} } @@ -118,26 +121,53 @@ func (t *TracingTransport) RoundTrip(req *http.Request) (resp *http.Response, er resp, err = t.transport.RoundTrip(req) called = true - event := postSuperCall(startTime, req.URL.String(), req.Method, resp, err, t.getMetadataOnly()) - t.addDataToEvent(req, resp, event) - t.tracer.AddEvent(event) + event := postSuperCall(startTime, req.URL.String(), req.Method, resp, err, t.getMetadataOnly(tr)) + t.addDataToEvent(reqHeaders, reqBody, req, event, tr) + tr.AddEvent(event) return } -func (t *TracingTransport) getMetadataOnly() bool { - return t.MetadataOnly || t.tracer.GetConfig().MetadataOnly +func (t *TracingTransport) getMetadataOnly(tr tracer.Tracer) bool { + return t.MetadataOnly || tr.GetConfig().MetadataOnly } -func (t *TracingTransport) addDataToEvent(req *http.Request, resp *http.Response, event *protocol.Event) { +func (t *TracingTransport) addDataToEvent(reqHeaders, reqBody string, req *http.Request, event *protocol.Event, tr tracer.Tracer) { if req != nil { addTraceIdToEvent(req, event) } - if resp != nil { - if !t.getMetadataOnly() { - updateRequestData(resp.Request, event.Resource.Metadata) - } + if !t.getMetadataOnly(tr) { + event.Resource.Metadata["request_headers"] = reqHeaders + event.Resource.Metadata["request_body"] = reqBody + } +} + +func (t *TracingTransport) extractRequestData(req *http.Request, tr tracer.Tracer) (headers string, body string) { + if t.getMetadataOnly(tr) { + return + } + + headers, err := formatHeaders(req.Header) + if err != nil { + headers = "" } + + if req.Body == nil { + return + } + + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + return + } + req.Body = ioutil.NopCloser(bytes.NewReader(buf)) + // truncates request body to the first 64KB + trimmed := buf + if len(buf) > MAX_METADATA_SIZE { + trimmed = buf[0:MAX_METADATA_SIZE] + } + body = string(trimmed) + return } func isBlacklistedURL(parsedUrl *url.URL) bool {