diff --git a/go.mod b/go.mod index c3fc127..4fed25a 100644 --- a/go.mod +++ b/go.mod @@ -17,4 +17,5 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/xid v1.5.0 // indirect ) diff --git a/go.sum b/go.sum index a4ca215..cf655e6 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,7 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= diff --git a/requestlog/accesslogger.go b/requestlog/accesslogger.go new file mode 100644 index 0000000..bae344c --- /dev/null +++ b/requestlog/accesslogger.go @@ -0,0 +1,104 @@ +package requestlog + +import ( + "bytes" + "encoding/json" + "net/http" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" +) + +const MaxRequestSizeLog = 4 * 1024 +const MaxStringRequestSizeLog = MaxRequestSizeLog / 2 + +func AccessLogger(logOptions bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log := hlog.FromRequest(r) + + crw := &CountingResponseWriter{ + ResponseWriter: w, + ResponseLength: -1, + StatusCode: -1, + } + + start := time.Now() + next.ServeHTTP(crw, r) + requestDuration := time.Since(start) + + if r.Method == http.MethodOptions && !logOptions { + return + } + + var requestLog *zerolog.Event + if crw.StatusCode >= 500 { + requestLog = log.Error() + } else if crw.StatusCode >= 400 { + requestLog = log.Warn() + } else { + requestLog = log.Info() + } + + if userAgent := r.UserAgent(); userAgent != "" { + requestLog.Str("user_agent", userAgent) + } + if referer := r.Referer(); referer != "" { + requestLog.Str("referer", referer) + } + remoteAddr := r.RemoteAddr + + requestLog.Str("remote_addr", remoteAddr) + requestLog.Str("method", r.Method) + requestLog.Str("proto", r.Proto) + requestLog.Int64("request_length", r.ContentLength) + requestLog.Str("host", r.Host) + requestLog.Str("request_uri", r.RequestURI) + if r.Method != http.MethodGet && r.Method != http.MethodHead { + requestLog.Str("request_content_type", r.Header.Get("Content-Type")) + if crw.RequestBody != nil { + logRequestMaybeJSON(requestLog, "request_body", crw.RequestBody.Bytes()) + } + } + + // response + requestLog.Int64("request_time_ms", requestDuration.Milliseconds()) + requestLog.Int("status_code", crw.StatusCode) + requestLog.Int("response_length", crw.ResponseLength) + requestLog.Str("response_content_type", crw.Header().Get("Content-Type")) + if crw.ResponseBody != nil { + logRequestMaybeJSON(requestLog, "response_body", crw.ResponseBody.Bytes()) + } + + // don't log successful health requests + if r.URL.Path == "/health" && crw.StatusCode == http.StatusNoContent { + return + } + + requestLog.Msg("Access") + }) + } +} + +func logRequestMaybeJSON(evt *zerolog.Event, key string, data []byte) { + data = removeNewlines(data) + if json.Valid(data) { + evt.RawJSON(key, data) + } else { + // Logging as a string will create lots of escaping and it's not valid json anyway, so cut off a bit more + if len(data) > MaxStringRequestSizeLog { + data = data[:MaxStringRequestSizeLog] + } + evt.Bytes(key+"_invalid", data) + } +} + +func removeNewlines(data []byte) []byte { + data = bytes.TrimSpace(data) + if bytes.ContainsRune(data, '\n') { + data = bytes.ReplaceAll(data, []byte{'\n'}, []byte{}) + data = bytes.ReplaceAll(data, []byte{'\r'}, []byte{}) + } + return data +} diff --git a/requestlog/countingresponsewriter.go b/requestlog/countingresponsewriter.go new file mode 100644 index 0000000..98cd0d4 --- /dev/null +++ b/requestlog/countingresponsewriter.go @@ -0,0 +1,62 @@ +package requestlog + +import ( + "bufio" + "bytes" + "fmt" + "net" + "net/http" + "strings" +) + +type CountingResponseWriter struct { + StatusCode int + ResponseLength int + Hijacked bool + ResponseWriter http.ResponseWriter + ResponseBody *bytes.Buffer + RequestBody *bytes.Buffer +} + +func (crw *CountingResponseWriter) Header() http.Header { + return crw.ResponseWriter.Header() +} + +func (crw *CountingResponseWriter) Write(data []byte) (int, error) { + if crw.ResponseLength == -1 { + crw.ResponseLength = 0 + } + if crw.StatusCode == -1 { + crw.StatusCode = http.StatusOK + } + crw.ResponseLength += len(data) + + if crw.ResponseBody != nil && crw.ResponseBody.Len() < MaxRequestSizeLog { + crw.ResponseBody.Write(CutRequestData(data, crw.ResponseBody.Len())) + } + return crw.ResponseWriter.Write(data) +} + +func (crw *CountingResponseWriter) WriteHeader(statusCode int) { + crw.StatusCode = statusCode + crw.ResponseWriter.WriteHeader(statusCode) + if !strings.HasPrefix(crw.Header().Get("Content-Type"), "application/json") { + crw.ResponseBody = nil + } +} + +func (crw *CountingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := crw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("CountingResponseWriter: %T does not implement http.Hijacker", crw.ResponseWriter) + } + crw.Hijacked = true + return hijacker.Hijack() +} + +func CutRequestData(data []byte, length int) []byte { + if len(data)+length > MaxRequestSizeLog { + return data[:MaxRequestSizeLog-length] + } + return data +} diff --git a/requestlog/route.go b/requestlog/route.go new file mode 100644 index 0000000..070ace9 --- /dev/null +++ b/requestlog/route.go @@ -0,0 +1,55 @@ +package requestlog + +import ( + "bytes" + "io" + "net/http" + "strings" +) + +type Route struct { + Path string + Method string + Handler http.HandlerFunc + + TrackHTTPMetrics func(*Route) func(*CountingResponseWriter) + + LogContent bool +} + +var _ http.Handler = (*Route)(nil) + +func (rt *Route) ServeHTTP(w http.ResponseWriter, r *http.Request) { + crw := w.(*CountingResponseWriter) + if rt.TrackHTTPMetrics != nil { + defer rt.TrackHTTPMetrics(rt)(crw) + } + if rt.LogContent { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + crw.ResponseBody = &bytes.Buffer{} + } + if strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") { + pcr := &partialCachingReader{Reader: r.Body} + crw.RequestBody = &pcr.Buffer + r.Body = pcr + } + } + rt.Handler(w, r) +} + +type partialCachingReader struct { + Reader io.ReadCloser + Buffer bytes.Buffer +} + +func (pcr *partialCachingReader) Read(p []byte) (int, error) { + n, err := pcr.Reader.Read(p) + if n > 0 { + pcr.Buffer.Write(CutRequestData(p[:n], pcr.Buffer.Len())) + } + return n, err +} + +func (pcr *partialCachingReader) Close() error { + return pcr.Reader.Close() +}